diff options
Diffstat (limited to 'lib/Transforms/Utils')
27 files changed, 4459 insertions, 2807 deletions
diff --git a/lib/Transforms/Utils/AddDiscriminators.cpp b/lib/Transforms/Utils/AddDiscriminators.cpp index 196ac79aaf29..820544bcebf0 100644 --- a/lib/Transforms/Utils/AddDiscriminators.cpp +++ b/lib/Transforms/Utils/AddDiscriminators.cpp @@ -167,7 +167,7 @@ bool AddDiscriminators::runOnFunction(Function &F) { bool Changed = false; Module *M = F.getParent(); LLVMContext &Ctx = M->getContext(); - DIBuilder Builder(*M); + DIBuilder Builder(*M, /*AllowUnresolved*/ false); // Traverse all the blocks looking for instructions in different // blocks that are at the same file:line location. @@ -193,13 +193,11 @@ bool AddDiscriminators::runOnFunction(Function &F) { // Create a new lexical scope and compute a new discriminator // number for it. StringRef Filename = FirstDIL.getFilename(); - unsigned LineNumber = FirstDIL.getLineNumber(); - unsigned ColumnNumber = FirstDIL.getColumnNumber(); DIScope Scope = FirstDIL.getScope(); DIFile File = Builder.createFile(Filename, Scope.getDirectory()); unsigned Discriminator = FirstDIL.computeNewDiscriminator(Ctx); - DILexicalBlock NewScope = Builder.createLexicalBlock( - Scope, File, LineNumber, ColumnNumber, Discriminator); + DILexicalBlockFile NewScope = + Builder.createLexicalBlockFile(Scope, File, Discriminator); DILocation NewDIL = FirstDIL.copyWithNewScope(Ctx, NewScope); DebugLoc newDebugLoc = DebugLoc::getFromDILocation(NewDIL); diff --git a/lib/Transforms/Utils/BasicBlockUtils.cpp b/lib/Transforms/Utils/BasicBlockUtils.cpp index 602e8ba55107..983f025a1a3a 100644 --- a/lib/Transforms/Utils/BasicBlockUtils.cpp +++ b/lib/Transforms/Utils/BasicBlockUtils.cpp @@ -265,6 +265,18 @@ BasicBlock *llvm::SplitEdge(BasicBlock *BB, BasicBlock *Succ, Pass *P) { return SplitBlock(BB, BB->getTerminator(), P); } +unsigned llvm::SplitAllCriticalEdges(Function &F, Pass *P) { + unsigned NumBroken = 0; + for (Function::iterator I = F.begin(), E = F.end(); I != E; ++I) { + TerminatorInst *TI = I->getTerminator(); + if (TI->getNumSuccessors() > 1 && !isa<IndirectBrInst>(TI)) + for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) + if (SplitCriticalEdge(TI, i, P)) + ++NumBroken; + } + return NumBroken; +} + /// SplitBlock - Split the specified block at the specified instruction - every /// thing before SplitPt stays in Old and everything starting with SplitPt moves /// to a new block. The two blocks are joined by an unconditional branch and diff --git a/lib/Transforms/Utils/BreakCriticalEdges.cpp b/lib/Transforms/Utils/BreakCriticalEdges.cpp index 80bd51637514..eda22cfc1bab 100644 --- a/lib/Transforms/Utils/BreakCriticalEdges.cpp +++ b/lib/Transforms/Utils/BreakCriticalEdges.cpp @@ -40,7 +40,11 @@ namespace { initializeBreakCriticalEdgesPass(*PassRegistry::getPassRegistry()); } - bool runOnFunction(Function &F) override; + bool runOnFunction(Function &F) override { + unsigned N = SplitAllCriticalEdges(F, this); + NumBroken += N; + return N > 0; + } void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addPreserved<DominatorTreeWrapperPass>(); @@ -62,24 +66,6 @@ FunctionPass *llvm::createBreakCriticalEdgesPass() { return new BreakCriticalEdges(); } -// runOnFunction - Loop over all of the edges in the CFG, breaking critical -// edges as they are found. -// -bool BreakCriticalEdges::runOnFunction(Function &F) { - bool Changed = false; - for (Function::iterator I = F.begin(), E = F.end(); I != E; ++I) { - TerminatorInst *TI = I->getTerminator(); - if (TI->getNumSuccessors() > 1 && !isa<IndirectBrInst>(TI)) - for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) - if (SplitCriticalEdge(TI, i, this)) { - ++NumBroken; - Changed = true; - } - } - - return Changed; -} - //===----------------------------------------------------------------------===// // Implementation of the external critical edge manipulation functions //===----------------------------------------------------------------------===// diff --git a/lib/Transforms/Utils/BuildLibCalls.cpp b/lib/Transforms/Utils/BuildLibCalls.cpp index be00b6956199..322485d9e32a 100644 --- a/lib/Transforms/Utils/BuildLibCalls.cpp +++ b/lib/Transforms/Utils/BuildLibCalls.cpp @@ -42,8 +42,7 @@ Value *llvm::EmitStrLen(Value *Ptr, IRBuilder<> &B, const DataLayout *TD, AttributeSet AS[2]; AS[0] = AttributeSet::get(M->getContext(), 1, Attribute::NoCapture); Attribute::AttrKind AVs[2] = { Attribute::ReadOnly, Attribute::NoUnwind }; - AS[1] = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, - ArrayRef<Attribute::AttrKind>(AVs, 2)); + AS[1] = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, AVs); LLVMContext &Context = B.GetInsertBlock()->getContext(); Constant *StrLen = M->getOrInsertFunction("strlen", @@ -51,7 +50,7 @@ Value *llvm::EmitStrLen(Value *Ptr, IRBuilder<> &B, const DataLayout *TD, AS), TD->getIntPtrType(Context), B.getInt8PtrTy(), - NULL); + nullptr); CallInst *CI = B.CreateCall(StrLen, CastToCStr(Ptr, B), "strlen"); if (const Function *F = dyn_cast<Function>(StrLen->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); @@ -71,8 +70,7 @@ Value *llvm::EmitStrNLen(Value *Ptr, Value *MaxLen, IRBuilder<> &B, AttributeSet AS[2]; AS[0] = AttributeSet::get(M->getContext(), 1, Attribute::NoCapture); Attribute::AttrKind AVs[2] = { Attribute::ReadOnly, Attribute::NoUnwind }; - AS[1] = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, - ArrayRef<Attribute::AttrKind>(AVs, 2)); + AS[1] = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, AVs); LLVMContext &Context = B.GetInsertBlock()->getContext(); Constant *StrNLen = M->getOrInsertFunction("strnlen", @@ -81,7 +79,7 @@ Value *llvm::EmitStrNLen(Value *Ptr, Value *MaxLen, IRBuilder<> &B, TD->getIntPtrType(Context), B.getInt8PtrTy(), TD->getIntPtrType(Context), - NULL); + nullptr); CallInst *CI = B.CreateCall2(StrNLen, CastToCStr(Ptr, B), MaxLen, "strnlen"); if (const Function *F = dyn_cast<Function>(StrNLen->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); @@ -100,15 +98,14 @@ Value *llvm::EmitStrChr(Value *Ptr, char C, IRBuilder<> &B, Module *M = B.GetInsertBlock()->getParent()->getParent(); Attribute::AttrKind AVs[2] = { Attribute::ReadOnly, Attribute::NoUnwind }; AttributeSet AS = - AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, - ArrayRef<Attribute::AttrKind>(AVs, 2)); + AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, AVs); Type *I8Ptr = B.getInt8PtrTy(); Type *I32Ty = B.getInt32Ty(); Constant *StrChr = M->getOrInsertFunction("strchr", AttributeSet::get(M->getContext(), AS), - I8Ptr, I8Ptr, I32Ty, NULL); + I8Ptr, I8Ptr, I32Ty, nullptr); CallInst *CI = B.CreateCall2(StrChr, CastToCStr(Ptr, B), ConstantInt::get(I32Ty, C), "strchr"); if (const Function *F = dyn_cast<Function>(StrChr->stripPointerCasts())) @@ -128,8 +125,7 @@ Value *llvm::EmitStrNCmp(Value *Ptr1, Value *Ptr2, Value *Len, AS[0] = AttributeSet::get(M->getContext(), 1, Attribute::NoCapture); AS[1] = AttributeSet::get(M->getContext(), 2, Attribute::NoCapture); Attribute::AttrKind AVs[2] = { Attribute::ReadOnly, Attribute::NoUnwind }; - AS[2] = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, - ArrayRef<Attribute::AttrKind>(AVs, 2)); + AS[2] = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, AVs); LLVMContext &Context = B.GetInsertBlock()->getContext(); Value *StrNCmp = M->getOrInsertFunction("strncmp", @@ -138,7 +134,7 @@ Value *llvm::EmitStrNCmp(Value *Ptr1, Value *Ptr2, Value *Len, B.getInt32Ty(), B.getInt8PtrTy(), B.getInt8PtrTy(), - TD->getIntPtrType(Context), NULL); + TD->getIntPtrType(Context), nullptr); CallInst *CI = B.CreateCall3(StrNCmp, CastToCStr(Ptr1, B), CastToCStr(Ptr2, B), Len, "strncmp"); @@ -164,7 +160,7 @@ Value *llvm::EmitStrCpy(Value *Dst, Value *Src, IRBuilder<> &B, Type *I8Ptr = B.getInt8PtrTy(); Value *StrCpy = M->getOrInsertFunction(Name, AttributeSet::get(M->getContext(), AS), - I8Ptr, I8Ptr, I8Ptr, NULL); + I8Ptr, I8Ptr, I8Ptr, nullptr); CallInst *CI = B.CreateCall2(StrCpy, CastToCStr(Dst, B), CastToCStr(Src, B), Name); if (const Function *F = dyn_cast<Function>(StrCpy->stripPointerCasts())) @@ -190,7 +186,7 @@ Value *llvm::EmitStrNCpy(Value *Dst, Value *Src, Value *Len, AttributeSet::get(M->getContext(), AS), I8Ptr, I8Ptr, I8Ptr, - Len->getType(), NULL); + Len->getType(), nullptr); CallInst *CI = B.CreateCall3(StrNCpy, CastToCStr(Dst, B), CastToCStr(Src, B), Len, "strncpy"); if (const Function *F = dyn_cast<Function>(StrNCpy->stripPointerCasts())) @@ -218,7 +214,7 @@ Value *llvm::EmitMemCpyChk(Value *Dst, Value *Src, Value *Len, Value *ObjSize, B.getInt8PtrTy(), B.getInt8PtrTy(), TD->getIntPtrType(Context), - TD->getIntPtrType(Context), NULL); + TD->getIntPtrType(Context), nullptr); Dst = CastToCStr(Dst, B); Src = CastToCStr(Src, B); CallInst *CI = B.CreateCall4(MemCpy, Dst, Src, Len, ObjSize); @@ -238,8 +234,7 @@ Value *llvm::EmitMemChr(Value *Ptr, Value *Val, Module *M = B.GetInsertBlock()->getParent()->getParent(); AttributeSet AS; Attribute::AttrKind AVs[2] = { Attribute::ReadOnly, Attribute::NoUnwind }; - AS = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, - ArrayRef<Attribute::AttrKind>(AVs, 2)); + AS = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, AVs); LLVMContext &Context = B.GetInsertBlock()->getContext(); Value *MemChr = M->getOrInsertFunction("memchr", AttributeSet::get(M->getContext(), AS), @@ -247,7 +242,7 @@ Value *llvm::EmitMemChr(Value *Ptr, Value *Val, B.getInt8PtrTy(), B.getInt32Ty(), TD->getIntPtrType(Context), - NULL); + nullptr); CallInst *CI = B.CreateCall3(MemChr, CastToCStr(Ptr, B), Val, Len, "memchr"); if (const Function *F = dyn_cast<Function>(MemChr->stripPointerCasts())) @@ -268,8 +263,7 @@ Value *llvm::EmitMemCmp(Value *Ptr1, Value *Ptr2, AS[0] = AttributeSet::get(M->getContext(), 1, Attribute::NoCapture); AS[1] = AttributeSet::get(M->getContext(), 2, Attribute::NoCapture); Attribute::AttrKind AVs[2] = { Attribute::ReadOnly, Attribute::NoUnwind }; - AS[2] = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, - ArrayRef<Attribute::AttrKind>(AVs, 2)); + AS[2] = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, AVs); LLVMContext &Context = B.GetInsertBlock()->getContext(); Value *MemCmp = M->getOrInsertFunction("memcmp", @@ -277,7 +271,7 @@ Value *llvm::EmitMemCmp(Value *Ptr1, Value *Ptr2, B.getInt32Ty(), B.getInt8PtrTy(), B.getInt8PtrTy(), - TD->getIntPtrType(Context), NULL); + TD->getIntPtrType(Context), nullptr); CallInst *CI = B.CreateCall3(MemCmp, CastToCStr(Ptr1, B), CastToCStr(Ptr2, B), Len, "memcmp"); @@ -313,7 +307,7 @@ Value *llvm::EmitUnaryFloatFnCall(Value *Op, StringRef Name, IRBuilder<> &B, Module *M = B.GetInsertBlock()->getParent()->getParent(); Value *Callee = M->getOrInsertFunction(Name, Op->getType(), - Op->getType(), NULL); + Op->getType(), nullptr); CallInst *CI = B.CreateCall(Callee, Op, Name); CI->setAttributes(Attrs); if (const Function *F = dyn_cast<Function>(Callee->stripPointerCasts())) @@ -334,7 +328,7 @@ Value *llvm::EmitBinaryFloatFnCall(Value *Op1, Value *Op2, StringRef Name, Module *M = B.GetInsertBlock()->getParent()->getParent(); Value *Callee = M->getOrInsertFunction(Name, Op1->getType(), - Op1->getType(), Op2->getType(), NULL); + Op1->getType(), Op2->getType(), nullptr); CallInst *CI = B.CreateCall2(Callee, Op1, Op2, Name); CI->setAttributes(Attrs); if (const Function *F = dyn_cast<Function>(Callee->stripPointerCasts())) @@ -352,7 +346,7 @@ Value *llvm::EmitPutChar(Value *Char, IRBuilder<> &B, const DataLayout *TD, Module *M = B.GetInsertBlock()->getParent()->getParent(); Value *PutChar = M->getOrInsertFunction("putchar", B.getInt32Ty(), - B.getInt32Ty(), NULL); + B.getInt32Ty(), nullptr); CallInst *CI = B.CreateCall(PutChar, B.CreateIntCast(Char, B.getInt32Ty(), @@ -382,7 +376,7 @@ Value *llvm::EmitPutS(Value *Str, IRBuilder<> &B, const DataLayout *TD, AttributeSet::get(M->getContext(), AS), B.getInt32Ty(), B.getInt8PtrTy(), - NULL); + nullptr); CallInst *CI = B.CreateCall(PutS, CastToCStr(Str, B), "puts"); if (const Function *F = dyn_cast<Function>(PutS->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); @@ -407,12 +401,12 @@ Value *llvm::EmitFPutC(Value *Char, Value *File, IRBuilder<> &B, AttributeSet::get(M->getContext(), AS), B.getInt32Ty(), B.getInt32Ty(), File->getType(), - NULL); + nullptr); else F = M->getOrInsertFunction("fputc", B.getInt32Ty(), B.getInt32Ty(), - File->getType(), NULL); + File->getType(), nullptr); Char = B.CreateIntCast(Char, B.getInt32Ty(), /*isSigned*/true, "chari"); CallInst *CI = B.CreateCall2(F, Char, File, "fputc"); @@ -442,11 +436,11 @@ Value *llvm::EmitFPutS(Value *Str, Value *File, IRBuilder<> &B, AttributeSet::get(M->getContext(), AS), B.getInt32Ty(), B.getInt8PtrTy(), - File->getType(), NULL); + File->getType(), nullptr); else F = M->getOrInsertFunction(FPutsName, B.getInt32Ty(), B.getInt8PtrTy(), - File->getType(), NULL); + File->getType(), nullptr); CallInst *CI = B.CreateCall2(F, CastToCStr(Str, B), File, "fputs"); if (const Function *Fn = dyn_cast<Function>(F->stripPointerCasts())) @@ -478,13 +472,13 @@ Value *llvm::EmitFWrite(Value *Ptr, Value *Size, Value *File, B.getInt8PtrTy(), TD->getIntPtrType(Context), TD->getIntPtrType(Context), - File->getType(), NULL); + File->getType(), nullptr); else F = M->getOrInsertFunction(FWriteName, TD->getIntPtrType(Context), B.getInt8PtrTy(), TD->getIntPtrType(Context), TD->getIntPtrType(Context), - File->getType(), NULL); + File->getType(), nullptr); CallInst *CI = B.CreateCall4(F, CastToCStr(Ptr, B), Size, ConstantInt::get(TD->getIntPtrType(Context), 1), File); @@ -492,135 +486,3 @@ Value *llvm::EmitFWrite(Value *Ptr, Value *Size, Value *File, CI->setCallingConv(Fn->getCallingConv()); return CI; } - -SimplifyFortifiedLibCalls::~SimplifyFortifiedLibCalls() { } - -bool SimplifyFortifiedLibCalls::fold(CallInst *CI, const DataLayout *TD, - const TargetLibraryInfo *TLI) { - // We really need DataLayout for later. - if (!TD) return false; - - this->CI = CI; - Function *Callee = CI->getCalledFunction(); - StringRef Name = Callee->getName(); - FunctionType *FT = Callee->getFunctionType(); - LLVMContext &Context = CI->getParent()->getContext(); - IRBuilder<> B(CI); - - if (Name == "__memcpy_chk") { - // Check if this has the right signature. - if (FT->getNumParams() != 4 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy() || - FT->getParamType(2) != TD->getIntPtrType(Context) || - FT->getParamType(3) != TD->getIntPtrType(Context)) - return false; - - if (isFoldable(3, 2, false)) { - B.CreateMemCpy(CI->getArgOperand(0), CI->getArgOperand(1), - CI->getArgOperand(2), 1); - replaceCall(CI->getArgOperand(0)); - return true; - } - return false; - } - - // Should be similar to memcpy. - if (Name == "__mempcpy_chk") { - return false; - } - - if (Name == "__memmove_chk") { - // Check if this has the right signature. - if (FT->getNumParams() != 4 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy() || - FT->getParamType(2) != TD->getIntPtrType(Context) || - FT->getParamType(3) != TD->getIntPtrType(Context)) - return false; - - if (isFoldable(3, 2, false)) { - B.CreateMemMove(CI->getArgOperand(0), CI->getArgOperand(1), - CI->getArgOperand(2), 1); - replaceCall(CI->getArgOperand(0)); - return true; - } - return false; - } - - if (Name == "__memset_chk") { - // Check if this has the right signature. - if (FT->getNumParams() != 4 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isIntegerTy() || - FT->getParamType(2) != TD->getIntPtrType(Context) || - FT->getParamType(3) != TD->getIntPtrType(Context)) - return false; - - if (isFoldable(3, 2, false)) { - Value *Val = B.CreateIntCast(CI->getArgOperand(1), B.getInt8Ty(), - false); - B.CreateMemSet(CI->getArgOperand(0), Val, CI->getArgOperand(2), 1); - replaceCall(CI->getArgOperand(0)); - return true; - } - return false; - } - - if (Name == "__strcpy_chk" || Name == "__stpcpy_chk") { - // Check if this has the right signature. - if (FT->getNumParams() != 3 || - FT->getReturnType() != FT->getParamType(0) || - FT->getParamType(0) != FT->getParamType(1) || - FT->getParamType(0) != Type::getInt8PtrTy(Context) || - FT->getParamType(2) != TD->getIntPtrType(Context)) - return 0; - - - // If a) we don't have any length information, or b) we know this will - // fit then just lower to a plain st[rp]cpy. Otherwise we'll keep our - // st[rp]cpy_chk call which may fail at runtime if the size is too long. - // TODO: It might be nice to get a maximum length out of the possible - // string lengths for varying. - if (isFoldable(2, 1, true)) { - Value *Ret = EmitStrCpy(CI->getArgOperand(0), CI->getArgOperand(1), B, TD, - TLI, Name.substr(2, 6)); - if (!Ret) - return false; - replaceCall(Ret); - return true; - } - return false; - } - - if (Name == "__strncpy_chk" || Name == "__stpncpy_chk") { - // Check if this has the right signature. - if (FT->getNumParams() != 4 || FT->getReturnType() != FT->getParamType(0) || - FT->getParamType(0) != FT->getParamType(1) || - FT->getParamType(0) != Type::getInt8PtrTy(Context) || - !FT->getParamType(2)->isIntegerTy() || - FT->getParamType(3) != TD->getIntPtrType(Context)) - return false; - - if (isFoldable(3, 2, false)) { - Value *Ret = EmitStrNCpy(CI->getArgOperand(0), CI->getArgOperand(1), - CI->getArgOperand(2), B, TD, TLI, - Name.substr(2, 7)); - if (!Ret) - return false; - replaceCall(Ret); - return true; - } - return false; - } - - if (Name == "__strcat_chk") { - return false; - } - - if (Name == "__strncat_chk") { - return false; - } - - return false; -} diff --git a/lib/Transforms/Utils/CMakeLists.txt b/lib/Transforms/Utils/CMakeLists.txt index fcf548f97c5d..6ce22b101825 100644 --- a/lib/Transforms/Utils/CMakeLists.txt +++ b/lib/Transforms/Utils/CMakeLists.txt @@ -1,16 +1,17 @@ add_llvm_library(LLVMTransformUtils - AddDiscriminators.cpp ASanStackFrameLayout.cpp + AddDiscriminators.cpp BasicBlockUtils.cpp BreakCriticalEdges.cpp BuildLibCalls.cpp BypassSlowDivision.cpp - CtorUtils.cpp CloneFunction.cpp CloneModule.cpp CmpInstAnalysis.cpp CodeExtractor.cpp + CtorUtils.cpp DemoteRegToStack.cpp + FlattenCFG.cpp GlobalStatus.cpp InlineFunction.cpp InstructionNamer.cpp @@ -29,10 +30,10 @@ add_llvm_library(LLVMTransformUtils PromoteMemoryToRegister.cpp SSAUpdater.cpp SimplifyCFG.cpp - FlattenCFG.cpp SimplifyIndVar.cpp SimplifyInstructions.cpp SimplifyLibCalls.cpp + SymbolRewriter.cpp UnifyFunctionExitNodes.cpp Utils.cpp ValueMapper.cpp diff --git a/lib/Transforms/Utils/CloneFunction.cpp b/lib/Transforms/Utils/CloneFunction.cpp index 5c8f20d5f884..96a763fac93b 100644 --- a/lib/Transforms/Utils/CloneFunction.cpp +++ b/lib/Transforms/Utils/CloneFunction.cpp @@ -164,14 +164,13 @@ static MDNode* FindSubprogram(const Function *F, DebugInfoFinder &Finder) { // Add an operand to an existing MDNode. The new operand will be added at the // back of the operand list. -static void AddOperand(MDNode *Node, Value *Operand) { - SmallVector<Value*, 16> Operands; - for (unsigned i = 0; i < Node->getNumOperands(); i++) { - Operands.push_back(Node->getOperand(i)); - } - Operands.push_back(Operand); - MDNode *NewNode = MDNode::get(Node->getContext(), Operands); - Node->replaceAllUsesWith(NewNode); +static void AddOperand(DICompileUnit CU, DIArray SPs, Metadata *NewSP) { + SmallVector<Metadata *, 16> NewSPs; + NewSPs.reserve(SPs->getNumOperands() + 1); + for (unsigned I = 0, E = SPs->getNumOperands(); I != E; ++I) + NewSPs.push_back(SPs->getOperand(I)); + NewSPs.push_back(NewSP); + CU.replaceSubprograms(DIArray(MDNode::get(CU->getContext(), NewSPs))); } // Clone the module-level debug info associated with OldFunc. The cloned data @@ -187,7 +186,7 @@ static void CloneDebugInfoMetadata(Function *NewFunc, const Function *OldFunc, // Ensure that OldFunc appears in the map. // (if it's already there it must point to NewFunc anyway) VMap[OldFunc] = NewFunc; - DISubprogram NewSubprogram(MapValue(OldSubprogramMDNode, VMap)); + DISubprogram NewSubprogram(MapMetadata(OldSubprogramMDNode, VMap)); for (DICompileUnit CU : Finder.compile_units()) { DIArray Subprograms(CU.getSubprograms()); @@ -196,7 +195,8 @@ static void CloneDebugInfoMetadata(Function *NewFunc, const Function *OldFunc, // also contain the new one. for (unsigned i = 0; i < Subprograms.getNumElements(); i++) { if ((MDNode*)Subprograms.getElement(i) == OldSubprogramMDNode) { - AddOperand(Subprograms, NewSubprogram); + AddOperand(CU, Subprograms, NewSubprogram); + break; } } } diff --git a/lib/Transforms/Utils/CloneModule.cpp b/lib/Transforms/Utils/CloneModule.cpp index 3f75b3e677ee..fae9ff5bce0f 100644 --- a/lib/Transforms/Utils/CloneModule.cpp +++ b/lib/Transforms/Utils/CloneModule.cpp @@ -17,6 +17,7 @@ #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Module.h" #include "llvm/Transforms/Utils/ValueMapper.h" +#include "llvm-c/Core.h" using namespace llvm; /// CloneModule - Return an exact copy of the specified module. This is not as @@ -108,7 +109,7 @@ Module *llvm::CloneModule(const Module *M, ValueToValueMapTy &VMap) { I != E; ++I) { GlobalAlias *GA = cast<GlobalAlias>(VMap[I]); if (const Constant *C = I->getAliasee()) - GA->setAliasee(cast<GlobalObject>(MapValue(C, VMap))); + GA->setAliasee(MapValue(C, VMap)); } // And named metadata.... @@ -117,8 +118,16 @@ Module *llvm::CloneModule(const Module *M, ValueToValueMapTy &VMap) { const NamedMDNode &NMD = *I; NamedMDNode *NewNMD = New->getOrInsertNamedMetadata(NMD.getName()); for (unsigned i = 0, e = NMD.getNumOperands(); i != e; ++i) - NewNMD->addOperand(MapValue(NMD.getOperand(i), VMap)); + NewNMD->addOperand(MapMetadata(NMD.getOperand(i), VMap)); } return New; } + +extern "C" { + +LLVMModuleRef LLVMCloneModule(LLVMModuleRef M) { + return wrap(CloneModule(unwrap(M))); +} + +} diff --git a/lib/Transforms/Utils/CtorUtils.cpp b/lib/Transforms/Utils/CtorUtils.cpp index a3594248de0c..26875e837b8b 100644 --- a/lib/Transforms/Utils/CtorUtils.cpp +++ b/lib/Transforms/Utils/CtorUtils.cpp @@ -11,6 +11,7 @@ // //===----------------------------------------------------------------------===// +#include "llvm/ADT/BitVector.h" #include "llvm/Transforms/Utils/CtorUtils.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" @@ -24,41 +25,22 @@ namespace llvm { namespace { -/// Given a specified llvm.global_ctors list, install the -/// specified array. -void installGlobalCtors(GlobalVariable *GCL, - const std::vector<Function *> &Ctors) { - // If we made a change, reassemble the initializer list. - Constant *CSVals[3]; - - StructType *StructTy = - cast<StructType>(GCL->getType()->getElementType()->getArrayElementType()); - - // Create the new init list. - std::vector<Constant *> CAList; - for (Function *F : Ctors) { - Type *Int32Ty = Type::getInt32Ty(GCL->getContext()); - if (F) { - CSVals[0] = ConstantInt::get(Int32Ty, 65535); - CSVals[1] = F; - } else { - CSVals[0] = ConstantInt::get(Int32Ty, 0x7fffffff); - CSVals[1] = Constant::getNullValue(StructTy->getElementType(1)); - } - // FIXME: Only allow the 3-field form in LLVM 4.0. - size_t NumElts = StructTy->getNumElements(); - if (NumElts > 2) - CSVals[2] = Constant::getNullValue(StructTy->getElementType(2)); - CAList.push_back( - ConstantStruct::get(StructTy, makeArrayRef(CSVals, NumElts))); - } - - // Create the array initializer. - Constant *CA = - ConstantArray::get(ArrayType::get(StructTy, CAList.size()), CAList); +/// Given a specified llvm.global_ctors list, remove the listed elements. +void removeGlobalCtors(GlobalVariable *GCL, const BitVector &CtorsToRemove) { + // Filter out the initializer elements to remove. + ConstantArray *OldCA = cast<ConstantArray>(GCL->getInitializer()); + SmallVector<Constant *, 10> CAList; + for (unsigned I = 0, E = OldCA->getNumOperands(); I < E; ++I) + if (!CtorsToRemove.test(I)) + CAList.push_back(OldCA->getOperand(I)); + + // Create the new array initializer. + ArrayType *ATy = + ArrayType::get(OldCA->getType()->getElementType(), CAList.size()); + Constant *CA = ConstantArray::get(ATy, CAList); // If we didn't change the number of elements, don't create a new GV. - if (CA->getType() == GCL->getInitializer()->getType()) { + if (CA->getType() == OldCA->getType()) { GCL->setInitializer(CA); return; } @@ -82,7 +64,7 @@ void installGlobalCtors(GlobalVariable *GCL, /// Given a llvm.global_ctors list that we can understand, /// return a list of the functions and null terminator as a vector. -std::vector<Function*> parseGlobalCtors(GlobalVariable *GV) { +std::vector<Function *> parseGlobalCtors(GlobalVariable *GV) { if (GV->getInitializer()->isNullValue()) return std::vector<Function *>(); ConstantArray *CA = cast<ConstantArray>(GV->getInitializer()); @@ -147,17 +129,15 @@ bool optimizeGlobalCtorsList(Module &M, bool MadeChange = false; // Loop over global ctors, optimizing them when we can. - for (unsigned i = 0; i != Ctors.size(); ++i) { + unsigned NumCtors = Ctors.size(); + BitVector CtorsToRemove(NumCtors); + for (unsigned i = 0; i != Ctors.size() && NumCtors > 0; ++i) { Function *F = Ctors[i]; // Found a null terminator in the middle of the list, prune off the rest of // the list. - if (!F) { - if (i != Ctors.size() - 1) { - Ctors.resize(i + 1); - MadeChange = true; - } - break; - } + if (!F) + continue; + DEBUG(dbgs() << "Optimizing Global Constructor: " << *F << "\n"); // We cannot simplify external ctor functions. @@ -166,9 +146,10 @@ bool optimizeGlobalCtorsList(Module &M, // If we can evaluate the ctor at compile time, do. if (ShouldRemove(F)) { - Ctors.erase(Ctors.begin() + i); + Ctors[i] = nullptr; + CtorsToRemove.set(i); + NumCtors--; MadeChange = true; - --i; continue; } } @@ -176,7 +157,7 @@ bool optimizeGlobalCtorsList(Module &M, if (!MadeChange) return false; - installGlobalCtors(GlobalCtors, Ctors); + removeGlobalCtors(GlobalCtors, CtorsToRemove); return true; } diff --git a/lib/Transforms/Utils/FlattenCFG.cpp b/lib/Transforms/Utils/FlattenCFG.cpp index 51ead40c916e..4eb3e3dd17d2 100644 --- a/lib/Transforms/Utils/FlattenCFG.cpp +++ b/lib/Transforms/Utils/FlattenCFG.cpp @@ -238,9 +238,13 @@ bool FlattenCFGOpt::FlattenParallelAndOr(BasicBlock *BB, IRBuilder<> &Builder, // Do branch inversion. BasicBlock *CurrBlock = LastCondBlock; bool EverChanged = false; - while (1) { + for (;CurrBlock != FirstCondBlock; + CurrBlock = CurrBlock->getSinglePredecessor()) { BranchInst *BI = dyn_cast<BranchInst>(CurrBlock->getTerminator()); CmpInst *CI = dyn_cast<CmpInst>(BI->getCondition()); + if (!CI) + continue; + CmpInst::Predicate Predicate = CI->getPredicate(); // Canonicalize icmp_ne -> icmp_eq, fcmp_one -> fcmp_oeq if ((Predicate == CmpInst::ICMP_NE) || (Predicate == CmpInst::FCMP_ONE)) { @@ -248,9 +252,6 @@ bool FlattenCFGOpt::FlattenParallelAndOr(BasicBlock *BB, IRBuilder<> &Builder, BI->swapSuccessors(); EverChanged = true; } - if (CurrBlock == FirstCondBlock) - break; - CurrBlock = CurrBlock->getSinglePredecessor(); } return EverChanged; } diff --git a/lib/Transforms/Utils/GlobalStatus.cpp b/lib/Transforms/Utils/GlobalStatus.cpp index 12057e4b929c..52e2d59557fa 100644 --- a/lib/Transforms/Utils/GlobalStatus.cpp +++ b/lib/Transforms/Utils/GlobalStatus.cpp @@ -35,6 +35,9 @@ bool llvm::isSafeToDestroyConstant(const Constant *C) { if (isa<GlobalValue>(C)) return false; + if (isa<ConstantInt>(C) || isa<ConstantFP>(C)) + return false; + for (const User *U : C->users()) if (const Constant *CU = dyn_cast<Constant>(U)) { if (!isSafeToDestroyConstant(CU)) @@ -45,7 +48,7 @@ bool llvm::isSafeToDestroyConstant(const Constant *C) { } static bool analyzeGlobalAux(const Value *V, GlobalStatus &GS, - SmallPtrSet<const PHINode *, 16> &PhiUsers) { + SmallPtrSetImpl<const PHINode *> &PhiUsers) { for (const Use &U : V->uses()) { const User *UR = U.getUser(); if (const ConstantExpr *CE = dyn_cast<ConstantExpr>(UR)) { @@ -130,7 +133,7 @@ static bool analyzeGlobalAux(const Value *V, GlobalStatus &GS, } else if (const PHINode *PN = dyn_cast<PHINode>(I)) { // PHI nodes we can check just like select or GEP instructions, but we // have to be careful about infinite recursion. - if (PhiUsers.insert(PN)) // Not already visited. + if (PhiUsers.insert(PN).second) // Not already visited. if (analyzeGlobalAux(I, GS, PhiUsers)) return true; } else if (isa<CmpInst>(I)) { diff --git a/lib/Transforms/Utils/InlineFunction.cpp b/lib/Transforms/Utils/InlineFunction.cpp index f0a9f2b1fcb3..2a86eb598d4d 100644 --- a/lib/Transforms/Utils/InlineFunction.cpp +++ b/lib/Transforms/Utils/InlineFunction.cpp @@ -13,10 +13,16 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CallGraph.h" +#include "llvm/Analysis/CaptureTracking.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/CFG.h" @@ -24,14 +30,28 @@ #include "llvm/IR/DataLayout.h" #include "llvm/IR/DebugInfo.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" +#include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" #include "llvm/Transforms/Utils/Local.h" +#include "llvm/Support/CommandLine.h" +#include <algorithm> using namespace llvm; +static cl::opt<bool> +EnableNoAliasConversion("enable-noalias-to-md-conversion", cl::init(true), + cl::Hidden, + cl::desc("Convert noalias attributes to metadata during inlining.")); + +static cl::opt<bool> +PreserveAlignmentAssumptions("preserve-alignment-assumptions-during-inlining", + cl::init(true), cl::Hidden, + cl::desc("Convert align attributes to assumptions during inlining.")); + bool llvm::InlineFunction(CallInst *CI, InlineFunctionInfo &IFI, bool InsertLifetime) { return InlineFunction(CallSite(CI), IFI, InsertLifetime); @@ -84,7 +104,7 @@ namespace { /// split the landing pad block after the landingpad instruction and jump /// to there. void forwardResume(ResumeInst *RI, - SmallPtrSet<LandingPadInst*, 16> &InlinedLPads); + SmallPtrSetImpl<LandingPadInst*> &InlinedLPads); /// addIncomingPHIValuesFor - Add incoming-PHI values to the unwind /// destination block for the given basic block, using the values for the @@ -143,7 +163,7 @@ BasicBlock *InvokeInliningInfo::getInnerResumeDest() { /// branch. When there is more than one predecessor, we need to split the /// landing pad block after the landingpad instruction and jump to there. void InvokeInliningInfo::forwardResume(ResumeInst *RI, - SmallPtrSet<LandingPadInst*, 16> &InlinedLPads) { + SmallPtrSetImpl<LandingPadInst*> &InlinedLPads) { BasicBlock *Dest = getInnerResumeDest(); BasicBlock *Src = RI->getParent(); @@ -233,9 +253,7 @@ static void HandleInlinedInvoke(InvokeInst *II, BasicBlock *FirstNewBlock, // Append the clauses from the outer landing pad instruction into the inlined // landing pad instructions. LandingPadInst *OuterLPad = Invoke.getLandingPadInst(); - for (SmallPtrSet<LandingPadInst*, 16>::iterator I = InlinedLPads.begin(), - E = InlinedLPads.end(); I != E; ++I) { - LandingPadInst *InlinedLPad = *I; + for (LandingPadInst *InlinedLPad : InlinedLPads) { unsigned OuterNum = OuterLPad->getNumClauses(); InlinedLPad->reserveClauses(OuterNum); for (unsigned OuterIdx = 0; OuterIdx != OuterNum; ++OuterIdx) @@ -260,6 +278,387 @@ static void HandleInlinedInvoke(InvokeInst *II, BasicBlock *FirstNewBlock, InvokeDest->removePredecessor(II->getParent()); } +/// CloneAliasScopeMetadata - When inlining a function that contains noalias +/// scope metadata, this metadata needs to be cloned so that the inlined blocks +/// have different "unqiue scopes" at every call site. Were this not done, then +/// aliasing scopes from a function inlined into a caller multiple times could +/// not be differentiated (and this would lead to miscompiles because the +/// non-aliasing property communicated by the metadata could have +/// call-site-specific control dependencies). +static void CloneAliasScopeMetadata(CallSite CS, ValueToValueMapTy &VMap) { + const Function *CalledFunc = CS.getCalledFunction(); + SetVector<const MDNode *> MD; + + // Note: We could only clone the metadata if it is already used in the + // caller. I'm omitting that check here because it might confuse + // inter-procedural alias analysis passes. We can revisit this if it becomes + // an efficiency or overhead problem. + + for (Function::const_iterator I = CalledFunc->begin(), IE = CalledFunc->end(); + I != IE; ++I) + for (BasicBlock::const_iterator J = I->begin(), JE = I->end(); J != JE; ++J) { + if (const MDNode *M = J->getMetadata(LLVMContext::MD_alias_scope)) + MD.insert(M); + if (const MDNode *M = J->getMetadata(LLVMContext::MD_noalias)) + MD.insert(M); + } + + if (MD.empty()) + return; + + // Walk the existing metadata, adding the complete (perhaps cyclic) chain to + // the set. + SmallVector<const Metadata *, 16> Queue(MD.begin(), MD.end()); + while (!Queue.empty()) { + const MDNode *M = cast<MDNode>(Queue.pop_back_val()); + for (unsigned i = 0, ie = M->getNumOperands(); i != ie; ++i) + if (const MDNode *M1 = dyn_cast<MDNode>(M->getOperand(i))) + if (MD.insert(M1)) + Queue.push_back(M1); + } + + // Now we have a complete set of all metadata in the chains used to specify + // the noalias scopes and the lists of those scopes. + SmallVector<MDNode *, 16> DummyNodes; + DenseMap<const MDNode *, TrackingMDNodeRef> MDMap; + for (SetVector<const MDNode *>::iterator I = MD.begin(), IE = MD.end(); + I != IE; ++I) { + MDNode *Dummy = MDNode::getTemporary(CalledFunc->getContext(), None); + DummyNodes.push_back(Dummy); + MDMap[*I].reset(Dummy); + } + + // Create new metadata nodes to replace the dummy nodes, replacing old + // metadata references with either a dummy node or an already-created new + // node. + for (SetVector<const MDNode *>::iterator I = MD.begin(), IE = MD.end(); + I != IE; ++I) { + SmallVector<Metadata *, 4> NewOps; + for (unsigned i = 0, ie = (*I)->getNumOperands(); i != ie; ++i) { + const Metadata *V = (*I)->getOperand(i); + if (const MDNode *M = dyn_cast<MDNode>(V)) + NewOps.push_back(MDMap[M]); + else + NewOps.push_back(const_cast<Metadata *>(V)); + } + + MDNode *NewM = MDNode::get(CalledFunc->getContext(), NewOps); + MDNodeFwdDecl *TempM = cast<MDNodeFwdDecl>(MDMap[*I]); + + TempM->replaceAllUsesWith(NewM); + } + + // Now replace the metadata in the new inlined instructions with the + // repacements from the map. + for (ValueToValueMapTy::iterator VMI = VMap.begin(), VMIE = VMap.end(); + VMI != VMIE; ++VMI) { + if (!VMI->second) + continue; + + Instruction *NI = dyn_cast<Instruction>(VMI->second); + if (!NI) + continue; + + if (MDNode *M = NI->getMetadata(LLVMContext::MD_alias_scope)) { + MDNode *NewMD = MDMap[M]; + // If the call site also had alias scope metadata (a list of scopes to + // which instructions inside it might belong), propagate those scopes to + // the inlined instructions. + if (MDNode *CSM = + CS.getInstruction()->getMetadata(LLVMContext::MD_alias_scope)) + NewMD = MDNode::concatenate(NewMD, CSM); + NI->setMetadata(LLVMContext::MD_alias_scope, NewMD); + } else if (NI->mayReadOrWriteMemory()) { + if (MDNode *M = + CS.getInstruction()->getMetadata(LLVMContext::MD_alias_scope)) + NI->setMetadata(LLVMContext::MD_alias_scope, M); + } + + if (MDNode *M = NI->getMetadata(LLVMContext::MD_noalias)) { + MDNode *NewMD = MDMap[M]; + // If the call site also had noalias metadata (a list of scopes with + // which instructions inside it don't alias), propagate those scopes to + // the inlined instructions. + if (MDNode *CSM = + CS.getInstruction()->getMetadata(LLVMContext::MD_noalias)) + NewMD = MDNode::concatenate(NewMD, CSM); + NI->setMetadata(LLVMContext::MD_noalias, NewMD); + } else if (NI->mayReadOrWriteMemory()) { + if (MDNode *M = CS.getInstruction()->getMetadata(LLVMContext::MD_noalias)) + NI->setMetadata(LLVMContext::MD_noalias, M); + } + } + + // Now that everything has been replaced, delete the dummy nodes. + for (unsigned i = 0, ie = DummyNodes.size(); i != ie; ++i) + MDNode::deleteTemporary(DummyNodes[i]); +} + +/// AddAliasScopeMetadata - If the inlined function has noalias arguments, then +/// add new alias scopes for each noalias argument, tag the mapped noalias +/// parameters with noalias metadata specifying the new scope, and tag all +/// non-derived loads, stores and memory intrinsics with the new alias scopes. +static void AddAliasScopeMetadata(CallSite CS, ValueToValueMapTy &VMap, + const DataLayout *DL, AliasAnalysis *AA) { + if (!EnableNoAliasConversion) + return; + + const Function *CalledFunc = CS.getCalledFunction(); + SmallVector<const Argument *, 4> NoAliasArgs; + + for (Function::const_arg_iterator I = CalledFunc->arg_begin(), + E = CalledFunc->arg_end(); I != E; ++I) { + if (I->hasNoAliasAttr() && !I->hasNUses(0)) + NoAliasArgs.push_back(I); + } + + if (NoAliasArgs.empty()) + return; + + // To do a good job, if a noalias variable is captured, we need to know if + // the capture point dominates the particular use we're considering. + DominatorTree DT; + DT.recalculate(const_cast<Function&>(*CalledFunc)); + + // noalias indicates that pointer values based on the argument do not alias + // pointer values which are not based on it. So we add a new "scope" for each + // noalias function argument. Accesses using pointers based on that argument + // become part of that alias scope, accesses using pointers not based on that + // argument are tagged as noalias with that scope. + + DenseMap<const Argument *, MDNode *> NewScopes; + MDBuilder MDB(CalledFunc->getContext()); + + // Create a new scope domain for this function. + MDNode *NewDomain = + MDB.createAnonymousAliasScopeDomain(CalledFunc->getName()); + for (unsigned i = 0, e = NoAliasArgs.size(); i != e; ++i) { + const Argument *A = NoAliasArgs[i]; + + std::string Name = CalledFunc->getName(); + if (A->hasName()) { + Name += ": %"; + Name += A->getName(); + } else { + Name += ": argument "; + Name += utostr(i); + } + + // Note: We always create a new anonymous root here. This is true regardless + // of the linkage of the callee because the aliasing "scope" is not just a + // property of the callee, but also all control dependencies in the caller. + MDNode *NewScope = MDB.createAnonymousAliasScope(NewDomain, Name); + NewScopes.insert(std::make_pair(A, NewScope)); + } + + // Iterate over all new instructions in the map; for all memory-access + // instructions, add the alias scope metadata. + for (ValueToValueMapTy::iterator VMI = VMap.begin(), VMIE = VMap.end(); + VMI != VMIE; ++VMI) { + if (const Instruction *I = dyn_cast<Instruction>(VMI->first)) { + if (!VMI->second) + continue; + + Instruction *NI = dyn_cast<Instruction>(VMI->second); + if (!NI) + continue; + + bool IsArgMemOnlyCall = false, IsFuncCall = false; + SmallVector<const Value *, 2> PtrArgs; + + if (const LoadInst *LI = dyn_cast<LoadInst>(I)) + PtrArgs.push_back(LI->getPointerOperand()); + else if (const StoreInst *SI = dyn_cast<StoreInst>(I)) + PtrArgs.push_back(SI->getPointerOperand()); + else if (const VAArgInst *VAAI = dyn_cast<VAArgInst>(I)) + PtrArgs.push_back(VAAI->getPointerOperand()); + else if (const AtomicCmpXchgInst *CXI = dyn_cast<AtomicCmpXchgInst>(I)) + PtrArgs.push_back(CXI->getPointerOperand()); + else if (const AtomicRMWInst *RMWI = dyn_cast<AtomicRMWInst>(I)) + PtrArgs.push_back(RMWI->getPointerOperand()); + else if (ImmutableCallSite ICS = ImmutableCallSite(I)) { + // If we know that the call does not access memory, then we'll still + // know that about the inlined clone of this call site, and we don't + // need to add metadata. + if (ICS.doesNotAccessMemory()) + continue; + + IsFuncCall = true; + if (AA) { + AliasAnalysis::ModRefBehavior MRB = AA->getModRefBehavior(ICS); + if (MRB == AliasAnalysis::OnlyAccessesArgumentPointees || + MRB == AliasAnalysis::OnlyReadsArgumentPointees) + IsArgMemOnlyCall = true; + } + + for (ImmutableCallSite::arg_iterator AI = ICS.arg_begin(), + AE = ICS.arg_end(); AI != AE; ++AI) { + // We need to check the underlying objects of all arguments, not just + // the pointer arguments, because we might be passing pointers as + // integers, etc. + // However, if we know that the call only accesses pointer arguments, + // then we only need to check the pointer arguments. + if (IsArgMemOnlyCall && !(*AI)->getType()->isPointerTy()) + continue; + + PtrArgs.push_back(*AI); + } + } + + // If we found no pointers, then this instruction is not suitable for + // pairing with an instruction to receive aliasing metadata. + // However, if this is a call, this we might just alias with none of the + // noalias arguments. + if (PtrArgs.empty() && !IsFuncCall) + continue; + + // It is possible that there is only one underlying object, but you + // need to go through several PHIs to see it, and thus could be + // repeated in the Objects list. + SmallPtrSet<const Value *, 4> ObjSet; + SmallVector<Metadata *, 4> Scopes, NoAliases; + + SmallSetVector<const Argument *, 4> NAPtrArgs; + for (unsigned i = 0, ie = PtrArgs.size(); i != ie; ++i) { + SmallVector<Value *, 4> Objects; + GetUnderlyingObjects(const_cast<Value*>(PtrArgs[i]), + Objects, DL, /* MaxLookup = */ 0); + + for (Value *O : Objects) + ObjSet.insert(O); + } + + // Figure out if we're derived from anything that is not a noalias + // argument. + bool CanDeriveViaCapture = false, UsesAliasingPtr = false; + for (const Value *V : ObjSet) { + // Is this value a constant that cannot be derived from any pointer + // value (we need to exclude constant expressions, for example, that + // are formed from arithmetic on global symbols). + bool IsNonPtrConst = isa<ConstantInt>(V) || isa<ConstantFP>(V) || + isa<ConstantPointerNull>(V) || + isa<ConstantDataVector>(V) || isa<UndefValue>(V); + if (IsNonPtrConst) + continue; + + // If this is anything other than a noalias argument, then we cannot + // completely describe the aliasing properties using alias.scope + // metadata (and, thus, won't add any). + if (const Argument *A = dyn_cast<Argument>(V)) { + if (!A->hasNoAliasAttr()) + UsesAliasingPtr = true; + } else { + UsesAliasingPtr = true; + } + + // If this is not some identified function-local object (which cannot + // directly alias a noalias argument), or some other argument (which, + // by definition, also cannot alias a noalias argument), then we could + // alias a noalias argument that has been captured). + if (!isa<Argument>(V) && + !isIdentifiedFunctionLocal(const_cast<Value*>(V))) + CanDeriveViaCapture = true; + } + + // A function call can always get captured noalias pointers (via other + // parameters, globals, etc.). + if (IsFuncCall && !IsArgMemOnlyCall) + CanDeriveViaCapture = true; + + // First, we want to figure out all of the sets with which we definitely + // don't alias. Iterate over all noalias set, and add those for which: + // 1. The noalias argument is not in the set of objects from which we + // definitely derive. + // 2. The noalias argument has not yet been captured. + // An arbitrary function that might load pointers could see captured + // noalias arguments via other noalias arguments or globals, and so we + // must always check for prior capture. + for (const Argument *A : NoAliasArgs) { + if (!ObjSet.count(A) && (!CanDeriveViaCapture || + // It might be tempting to skip the + // PointerMayBeCapturedBefore check if + // A->hasNoCaptureAttr() is true, but this is + // incorrect because nocapture only guarantees + // that no copies outlive the function, not + // that the value cannot be locally captured. + !PointerMayBeCapturedBefore(A, + /* ReturnCaptures */ false, + /* StoreCaptures */ false, I, &DT))) + NoAliases.push_back(NewScopes[A]); + } + + if (!NoAliases.empty()) + NI->setMetadata(LLVMContext::MD_noalias, + MDNode::concatenate( + NI->getMetadata(LLVMContext::MD_noalias), + MDNode::get(CalledFunc->getContext(), NoAliases))); + + // Next, we want to figure out all of the sets to which we might belong. + // We might belong to a set if the noalias argument is in the set of + // underlying objects. If there is some non-noalias argument in our list + // of underlying objects, then we cannot add a scope because the fact + // that some access does not alias with any set of our noalias arguments + // cannot itself guarantee that it does not alias with this access + // (because there is some pointer of unknown origin involved and the + // other access might also depend on this pointer). We also cannot add + // scopes to arbitrary functions unless we know they don't access any + // non-parameter pointer-values. + bool CanAddScopes = !UsesAliasingPtr; + if (CanAddScopes && IsFuncCall) + CanAddScopes = IsArgMemOnlyCall; + + if (CanAddScopes) + for (const Argument *A : NoAliasArgs) { + if (ObjSet.count(A)) + Scopes.push_back(NewScopes[A]); + } + + if (!Scopes.empty()) + NI->setMetadata( + LLVMContext::MD_alias_scope, + MDNode::concatenate(NI->getMetadata(LLVMContext::MD_alias_scope), + MDNode::get(CalledFunc->getContext(), Scopes))); + } + } +} + +/// If the inlined function has non-byval align arguments, then +/// add @llvm.assume-based alignment assumptions to preserve this information. +static void AddAlignmentAssumptions(CallSite CS, InlineFunctionInfo &IFI) { + if (!PreserveAlignmentAssumptions || !IFI.DL) + return; + + // To avoid inserting redundant assumptions, we should check for assumptions + // already in the caller. To do this, we might need a DT of the caller. + DominatorTree DT; + bool DTCalculated = false; + + Function *CalledFunc = CS.getCalledFunction(); + for (Function::arg_iterator I = CalledFunc->arg_begin(), + E = CalledFunc->arg_end(); + I != E; ++I) { + unsigned Align = I->getType()->isPointerTy() ? I->getParamAlignment() : 0; + if (Align && !I->hasByValOrInAllocaAttr() && !I->hasNUses(0)) { + if (!DTCalculated) { + DT.recalculate(const_cast<Function&>(*CS.getInstruction()->getParent() + ->getParent())); + DTCalculated = true; + } + + // If we can already prove the asserted alignment in the context of the + // caller, then don't bother inserting the assumption. + Value *Arg = CS.getArgument(I->getArgNo()); + if (getKnownAlignment(Arg, IFI.DL, + &IFI.ACT->getAssumptionCache(*CalledFunc), + CS.getInstruction(), &DT) >= Align) + continue; + + IRBuilder<>(CS.getInstruction()).CreateAlignmentAssumption(*IFI.DL, Arg, + Align); + } + } +} + /// UpdateCallGraphAfterInlining - Once we have cloned code over from a callee /// into the caller, update the specified callgraph to reflect the changes we /// made. Note that it's possible that not all code was copied over, so only @@ -327,31 +726,19 @@ static void UpdateCallGraphAfterInlining(CallSite CS, static void HandleByValArgumentInit(Value *Dst, Value *Src, Module *M, BasicBlock *InsertBlock, InlineFunctionInfo &IFI) { - LLVMContext &Context = Src->getContext(); - Type *VoidPtrTy = Type::getInt8PtrTy(Context); Type *AggTy = cast<PointerType>(Src->getType())->getElementType(); - Type *Tys[3] = { VoidPtrTy, VoidPtrTy, Type::getInt64Ty(Context) }; - Function *MemCpyFn = Intrinsic::getDeclaration(M, Intrinsic::memcpy, Tys); - IRBuilder<> builder(InsertBlock->begin()); - Value *DstCast = builder.CreateBitCast(Dst, VoidPtrTy, "tmp"); - Value *SrcCast = builder.CreateBitCast(Src, VoidPtrTy, "tmp"); + IRBuilder<> Builder(InsertBlock->begin()); Value *Size; if (IFI.DL == nullptr) Size = ConstantExpr::getSizeOf(AggTy); else - Size = ConstantInt::get(Type::getInt64Ty(Context), - IFI.DL->getTypeStoreSize(AggTy)); + Size = Builder.getInt64(IFI.DL->getTypeStoreSize(AggTy)); // Always generate a memcpy of alignment 1 here because we don't know // the alignment of the src pointer. Other optimizations can infer // better alignment. - Value *CallArgs[] = { - DstCast, SrcCast, Size, - ConstantInt::get(Type::getInt32Ty(Context), 1), - ConstantInt::getFalse(Context) // isVolatile - }; - builder.CreateCall(MemCpyFn, CallArgs); + Builder.CreateMemCpy(Dst, Src, Size, /*Align=*/1); } /// HandleByValArgument - When inlining a call site that has a byval argument, @@ -363,6 +750,8 @@ static Value *HandleByValArgument(Value *Arg, Instruction *TheCall, PointerType *ArgTy = cast<PointerType>(Arg->getType()); Type *AggTy = ArgTy->getElementType(); + Function *Caller = TheCall->getParent()->getParent(); + // If the called function is readonly, then it could not mutate the caller's // copy of the byval'd memory. In this case, it is safe to elide the copy and // temporary. @@ -375,8 +764,9 @@ static Value *HandleByValArgument(Value *Arg, Instruction *TheCall, // If the pointer is already known to be sufficiently aligned, or if we can // round it up to a larger alignment, then we don't need a temporary. - if (getOrEnforceKnownAlignment(Arg, ByValAlignment, - IFI.DL) >= ByValAlignment) + if (getOrEnforceKnownAlignment(Arg, ByValAlignment, IFI.DL, + &IFI.ACT->getAssumptionCache(*Caller), + TheCall) >= ByValAlignment) return Arg; // Otherwise, we have to make a memcpy to get a safe alignment. This is bad @@ -393,8 +783,6 @@ static Value *HandleByValArgument(Value *Arg, Instruction *TheCall, // pointer inside the callee). Align = std::max(Align, ByValAlignment); - Function *Caller = TheCall->getParent()->getParent(); - Value *NewAlloca = new AllocaInst(AggTy, nullptr, Align, Arg->getName(), &*Caller->begin()->begin()); IFI.StaticAllocas.push_back(cast<AllocaInst>(NewAlloca)); @@ -472,47 +860,33 @@ static void fixupLineNumbers(Function *Fn, Function::iterator FI, // originates from the call location. This is important for // ((__always_inline__, __nodebug__)) functions which must use caller // location for all instructions in their function body. + + // Don't update static allocas, as they may get moved later. + if (auto *AI = dyn_cast<AllocaInst>(BI)) + if (isa<Constant>(AI->getArraySize())) + continue; + BI->setDebugLoc(TheCallDL); } else { BI->setDebugLoc(updateInlinedAtInfo(DL, TheCallDL, BI->getContext())); if (DbgValueInst *DVI = dyn_cast<DbgValueInst>(BI)) { LLVMContext &Ctx = BI->getContext(); MDNode *InlinedAt = BI->getDebugLoc().getInlinedAt(Ctx); - DVI->setOperand(2, createInlinedVariable(DVI->getVariable(), - InlinedAt, Ctx)); + DVI->setOperand(2, MetadataAsValue::get( + Ctx, createInlinedVariable(DVI->getVariable(), + InlinedAt, Ctx))); + } else if (DbgDeclareInst *DDI = dyn_cast<DbgDeclareInst>(BI)) { + LLVMContext &Ctx = BI->getContext(); + MDNode *InlinedAt = BI->getDebugLoc().getInlinedAt(Ctx); + DDI->setOperand(1, MetadataAsValue::get( + Ctx, createInlinedVariable(DDI->getVariable(), + InlinedAt, Ctx))); } } } } } -/// Returns a musttail call instruction if one immediately precedes the given -/// return instruction with an optional bitcast instruction between them. -static CallInst *getPrecedingMustTailCall(ReturnInst *RI) { - Instruction *Prev = RI->getPrevNode(); - if (!Prev) - return nullptr; - - if (Value *RV = RI->getReturnValue()) { - if (RV != Prev) - return nullptr; - - // Look through the optional bitcast. - if (auto *BI = dyn_cast<BitCastInst>(Prev)) { - RV = BI->getOperand(0); - Prev = BI->getPrevNode(); - if (!Prev || RV != Prev) - return nullptr; - } - } - - if (auto *CI = dyn_cast<CallInst>(Prev)) { - if (CI->isMustTailCall()) - return CI; - } - return nullptr; -} - /// InlineFunction - This function inlines the called function into the basic /// block of the caller. This returns false if it is not possible to inline /// this call. The program is still in a well defined state if this occurs @@ -626,6 +1000,11 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, VMap[I] = ActualArg; } + // Add alignment assumptions if necessary. We do this before the inlined + // instructions are actually cloned into the caller so that we can easily + // check what will be known at the start of the inlined code. + AddAlignmentAssumptions(CS, IFI); + // We want the inliner to prune the code as it copies. We would LOVE to // have no dead or constant instructions leftover after inlining occurs // (which can happen, e.g., because an argument was constant), but we'll be @@ -648,6 +1027,17 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, // Update inlined instructions' line number information. fixupLineNumbers(Caller, FirstNewBlock, TheCall); + + // Clone existing noalias metadata if necessary. + CloneAliasScopeMetadata(CS, VMap); + + // Add noalias metadata if necessary. + AddAliasScopeMetadata(CS, VMap, IFI.DL, IFI.AA); + + // FIXME: We could register any cloned assumptions instead of clearing the + // whole function's cache. + if (IFI.ACT) + IFI.ACT->getAssumptionCache(*Caller).clear(); } // If there are any alloca instructions in the block that used to be the entry @@ -765,7 +1155,8 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, for (ReturnInst *RI : Returns) { // Don't insert llvm.lifetime.end calls between a musttail call and a // return. The return kills all local allocas. - if (InlinedMustTailCalls && getPrecedingMustTailCall(RI)) + if (InlinedMustTailCalls && + RI->getParent()->getTerminatingMustTailCall()) continue; IRBuilder<>(RI).CreateLifetimeEnd(AI, AllocaSize); } @@ -789,7 +1180,7 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, for (ReturnInst *RI : Returns) { // Don't insert llvm.stackrestore calls between a musttail call and a // return. The return will restore the stack pointer. - if (InlinedMustTailCalls && getPrecedingMustTailCall(RI)) + if (InlinedMustTailCalls && RI->getParent()->getTerminatingMustTailCall()) continue; IRBuilder<>(RI).CreateCall(StackRestore, SavedPtr); } @@ -812,7 +1203,8 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, // Handle the returns preceded by musttail calls separately. SmallVector<ReturnInst *, 8> NormalReturns; for (ReturnInst *RI : Returns) { - CallInst *ReturnedMustTail = getPrecedingMustTailCall(RI); + CallInst *ReturnedMustTail = + RI->getParent()->getTerminatingMustTailCall(); if (!ReturnedMustTail) { NormalReturns.push_back(RI); continue; @@ -1016,7 +1408,8 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, // the entries are the same or undef). If so, remove the PHI so it doesn't // block other optimizations. if (PHI) { - if (Value *V = SimplifyInstruction(PHI, IFI.DL)) { + if (Value *V = SimplifyInstruction(PHI, IFI.DL, nullptr, nullptr, + &IFI.ACT->getAssumptionCache(*Caller))) { PHI->replaceAllUsesWith(V); PHI->eraseFromParent(); } diff --git a/lib/Transforms/Utils/IntegerDivision.cpp b/lib/Transforms/Utils/IntegerDivision.cpp index 9f91eeb79531..0ae746cc83db 100644 --- a/lib/Transforms/Utils/IntegerDivision.cpp +++ b/lib/Transforms/Utils/IntegerDivision.cpp @@ -398,11 +398,13 @@ bool llvm::expandRemainder(BinaryOperator *Rem) { Rem->dropAllReferences(); Rem->eraseFromParent(); - // If we didn't actually generate a udiv instruction, we're done - BinaryOperator *BO = dyn_cast<BinaryOperator>(Builder.GetInsertPoint()); - if (!BO || BO->getOpcode() != Instruction::URem) + // If we didn't actually generate an urem instruction, we're done + // This happens for example if the input were constant. In this case the + // Builder insertion point was unchanged + if (Rem == Builder.GetInsertPoint()) return true; + BinaryOperator *BO = dyn_cast<BinaryOperator>(Builder.GetInsertPoint()); Rem = BO; } @@ -456,11 +458,13 @@ bool llvm::expandDivision(BinaryOperator *Div) { Div->dropAllReferences(); Div->eraseFromParent(); - // If we didn't actually generate a udiv instruction, we're done - BinaryOperator *BO = dyn_cast<BinaryOperator>(Builder.GetInsertPoint()); - if (!BO || BO->getOpcode() != Instruction::UDiv) + // If we didn't actually generate an udiv instruction, we're done + // This happens for example if the input were constant. In this case the + // Builder insertion point was unchanged + if (Div == Builder.GetInsertPoint()) return true; + BinaryOperator *BO = dyn_cast<BinaryOperator>(Builder.GetInsertPoint()); Div = BO; } diff --git a/lib/Transforms/Utils/LCSSA.cpp b/lib/Transforms/Utils/LCSSA.cpp index 51a3d9c1fced..3f9b702c5b9a 100644 --- a/lib/Transforms/Utils/LCSSA.cpp +++ b/lib/Transforms/Utils/LCSSA.cpp @@ -61,7 +61,7 @@ static bool isExitBlock(BasicBlock *BB, /// uses. static bool processInstruction(Loop &L, Instruction &Inst, DominatorTree &DT, const SmallVectorImpl<BasicBlock *> &ExitBlocks, - PredIteratorCache &PredCache) { + PredIteratorCache &PredCache, LoopInfo *LI) { SmallVector<Use *, 16> UsesToRewrite; BasicBlock *InstBB = Inst.getParent(); @@ -94,6 +94,7 @@ static bool processInstruction(Loop &L, Instruction &Inst, DominatorTree &DT, DomTreeNode *DomNode = DT.getNode(DomBB); SmallVector<PHINode *, 16> AddedPHIs; + SmallVector<PHINode *, 8> PostProcessPHIs; SSAUpdater SSAUpdate; SSAUpdate.Initialize(Inst.getType(), Inst.getName()); @@ -131,6 +132,18 @@ static bool processInstruction(Loop &L, Instruction &Inst, DominatorTree &DT, // Remember that this phi makes the value alive in this block. SSAUpdate.AddAvailableValue(ExitBB, PN); + + // LoopSimplify might fail to simplify some loops (e.g. when indirect + // branches are involved). In such situations, it might happen that an exit + // for Loop L1 is the header of a disjoint Loop L2. Thus, when we create + // PHIs in such an exit block, we are also inserting PHIs into L2's header. + // This could break LCSSA form for L2 because these inserted PHIs can also + // have uses outside of L2. Remember all PHIs in such situation as to + // revisit than later on. FIXME: Remove this if indirectbr support into + // LoopSimplify gets improved. + if (auto *OtherLoop = LI->getLoopFor(ExitBB)) + if (!L.contains(OtherLoop)) + PostProcessPHIs.push_back(PN); } // Rewrite all uses outside the loop in terms of the new PHIs we just @@ -157,6 +170,25 @@ static bool processInstruction(Loop &L, Instruction &Inst, DominatorTree &DT, SSAUpdate.RewriteUse(*UsesToRewrite[i]); } + // Post process PHI instructions that were inserted into another disjoint loop + // and update their exits properly. + for (auto *I : PostProcessPHIs) { + if (I->use_empty()) + continue; + + BasicBlock *PHIBB = I->getParent(); + Loop *OtherLoop = LI->getLoopFor(PHIBB); + SmallVector<BasicBlock *, 8> EBs; + OtherLoop->getExitBlocks(EBs); + if (EBs.empty()) + continue; + + // Recurse and re-process each PHI instruction. FIXME: we should really + // convert this entire thing to a worklist approach where we process a + // vector of instructions... + processInstruction(*OtherLoop, *I, DT, EBs, PredCache, LI); + } + // Remove PHI nodes that did not have any uses rewritten. for (unsigned i = 0, e = AddedPHIs.size(); i != e; ++i) { if (AddedPHIs[i]->use_empty()) @@ -180,7 +212,8 @@ blockDominatesAnExit(BasicBlock *BB, return false; } -bool llvm::formLCSSA(Loop &L, DominatorTree &DT, ScalarEvolution *SE) { +bool llvm::formLCSSA(Loop &L, DominatorTree &DT, LoopInfo *LI, + ScalarEvolution *SE) { bool Changed = false; // Get the set of exiting blocks. @@ -212,7 +245,7 @@ bool llvm::formLCSSA(Loop &L, DominatorTree &DT, ScalarEvolution *SE) { !isa<PHINode>(I->user_back()))) continue; - Changed |= processInstruction(L, *I, DT, ExitBlocks, PredCache); + Changed |= processInstruction(L, *I, DT, ExitBlocks, PredCache, LI); } } @@ -228,15 +261,15 @@ bool llvm::formLCSSA(Loop &L, DominatorTree &DT, ScalarEvolution *SE) { } /// Process a loop nest depth first. -bool llvm::formLCSSARecursively(Loop &L, DominatorTree &DT, +bool llvm::formLCSSARecursively(Loop &L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution *SE) { bool Changed = false; // Recurse depth-first through inner loops. - for (Loop::iterator LI = L.begin(), LE = L.end(); LI != LE; ++LI) - Changed |= formLCSSARecursively(**LI, DT, SE); + for (Loop::iterator I = L.begin(), E = L.end(); I != E; ++I) + Changed |= formLCSSARecursively(**I, DT, LI, SE); - Changed |= formLCSSA(L, DT, SE); + Changed |= formLCSSA(L, DT, LI, SE); return Changed; } @@ -291,7 +324,7 @@ bool LCSSA::runOnFunction(Function &F) { // Simplify each loop nest in the function. for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I) - Changed |= formLCSSARecursively(**I, *DT, SE); + Changed |= formLCSSARecursively(**I, *DT, LI, SE); return Changed; } diff --git a/lib/Transforms/Utils/Local.cpp b/lib/Transforms/Utils/Local.cpp index a5e443fcf46b..08a4b3f3b737 100644 --- a/lib/Transforms/Utils/Local.cpp +++ b/lib/Transforms/Utils/Local.cpp @@ -128,7 +128,7 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, // Check to see if this branch is going to the same place as the default // dest. If so, eliminate it as an explicit compare. if (i.getCaseSuccessor() == DefaultDest) { - MDNode* MD = SI->getMetadata(LLVMContext::MD_prof); + MDNode *MD = SI->getMetadata(LLVMContext::MD_prof); unsigned NCases = SI->getNumCases(); // Fold the case metadata into the default if there will be any branches // left, unless the metadata doesn't match the switch. @@ -137,7 +137,8 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, SmallVector<uint32_t, 8> Weights; for (unsigned MD_i = 1, MD_e = MD->getNumOperands(); MD_i < MD_e; ++MD_i) { - ConstantInt* CI = dyn_cast<ConstantInt>(MD->getOperand(MD_i)); + ConstantInt *CI = + mdconst::dyn_extract<ConstantInt>(MD->getOperand(MD_i)); assert(CI); Weights.push_back(CI->getValue().getZExtValue()); } @@ -206,10 +207,12 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, BranchInst *NewBr = Builder.CreateCondBr(Cond, FirstCase.getCaseSuccessor(), SI->getDefaultDest()); - MDNode* MD = SI->getMetadata(LLVMContext::MD_prof); + MDNode *MD = SI->getMetadata(LLVMContext::MD_prof); if (MD && MD->getNumOperands() == 3) { - ConstantInt *SICase = dyn_cast<ConstantInt>(MD->getOperand(2)); - ConstantInt *SIDef = dyn_cast<ConstantInt>(MD->getOperand(1)); + ConstantInt *SICase = + mdconst::dyn_extract<ConstantInt>(MD->getOperand(2)); + ConstantInt *SIDef = + mdconst::dyn_extract<ConstantInt>(MD->getOperand(1)); assert(SICase && SIDef); // The TrueWeight should be the weight for the single case of SI. NewBr->setMetadata(LLVMContext::MD_prof, @@ -301,6 +304,14 @@ bool llvm::isInstructionTriviallyDead(Instruction *I, if (II->getIntrinsicID() == Intrinsic::lifetime_start || II->getIntrinsicID() == Intrinsic::lifetime_end) return isa<UndefValue>(II->getArgOperand(1)); + + // Assumptions are dead if their condition is trivially true. + if (II->getIntrinsicID() == Intrinsic::assume) { + if (ConstantInt *Cond = dyn_cast<ConstantInt>(II->getArgOperand(0))) + return !Cond->isZero(); + + return false; + } } if (isAllocLikeFn(I, TLI)) return true; @@ -384,7 +395,7 @@ bool llvm::RecursivelyDeleteDeadPHINode(PHINode *PN, // If we find an instruction more than once, we're on a cycle that // won't prove fruitful. - if (!Visited.insert(I)) { + if (!Visited.insert(I).second) { // Break the cycle and delete the instruction and its operands. I->replaceAllUsesWith(UndefValue::get(I->getType())); (void)RecursivelyDeleteTriviallyDeadInstructions(I, TLI); @@ -931,13 +942,16 @@ static unsigned enforceKnownAlignment(Value *V, unsigned Align, /// and it is more than the alignment of the ultimate object, see if we can /// increase the alignment of the ultimate object, making this check succeed. unsigned llvm::getOrEnforceKnownAlignment(Value *V, unsigned PrefAlign, - const DataLayout *DL) { + const DataLayout *DL, + AssumptionCache *AC, + const Instruction *CxtI, + const DominatorTree *DT) { assert(V->getType()->isPointerTy() && "getOrEnforceKnownAlignment expects a pointer!"); unsigned BitWidth = DL ? DL->getPointerTypeSizeInBits(V->getType()) : 64; APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); - computeKnownBits(V, KnownZero, KnownOne, DL); + computeKnownBits(V, KnownZero, KnownOne, DL, 0, AC, CxtI, DT); unsigned TrailZ = KnownZero.countTrailingOnes(); // Avoid trouble with ridiculously large TrailZ values, such as @@ -982,6 +996,7 @@ static bool LdStHasDebugValue(DIVariable &DIVar, Instruction *I) { bool llvm::ConvertDebugDeclareToDebugValue(DbgDeclareInst *DDI, StoreInst *SI, DIBuilder &Builder) { DIVariable DIVar(DDI->getVariable()); + DIExpression DIExpr(DDI->getExpression()); assert((!DIVar || DIVar.isVariable()) && "Variable in DbgDeclareInst should be either null or a DIVariable."); if (!DIVar) @@ -999,9 +1014,10 @@ bool llvm::ConvertDebugDeclareToDebugValue(DbgDeclareInst *DDI, if (SExtInst *SExt = dyn_cast<SExtInst>(SI->getOperand(0))) ExtendedArg = dyn_cast<Argument>(SExt->getOperand(0)); if (ExtendedArg) - DbgVal = Builder.insertDbgValueIntrinsic(ExtendedArg, 0, DIVar, SI); + DbgVal = Builder.insertDbgValueIntrinsic(ExtendedArg, 0, DIVar, DIExpr, SI); else - DbgVal = Builder.insertDbgValueIntrinsic(SI->getOperand(0), 0, DIVar, SI); + DbgVal = Builder.insertDbgValueIntrinsic(SI->getOperand(0), 0, DIVar, + DIExpr, SI); DbgVal->setDebugLoc(DDI->getDebugLoc()); return true; } @@ -1011,6 +1027,7 @@ bool llvm::ConvertDebugDeclareToDebugValue(DbgDeclareInst *DDI, bool llvm::ConvertDebugDeclareToDebugValue(DbgDeclareInst *DDI, LoadInst *LI, DIBuilder &Builder) { DIVariable DIVar(DDI->getVariable()); + DIExpression DIExpr(DDI->getExpression()); assert((!DIVar || DIVar.isVariable()) && "Variable in DbgDeclareInst should be either null or a DIVariable."); if (!DIVar) @@ -1020,8 +1037,7 @@ bool llvm::ConvertDebugDeclareToDebugValue(DbgDeclareInst *DDI, return true; Instruction *DbgVal = - Builder.insertDbgValueIntrinsic(LI->getOperand(0), 0, - DIVar, LI); + Builder.insertDbgValueIntrinsic(LI->getOperand(0), 0, DIVar, DIExpr, LI); DbgVal->setDebugLoc(DDI->getDebugLoc()); return true; } @@ -1035,7 +1051,7 @@ static bool isArray(AllocaInst *AI) { /// LowerDbgDeclare - Lowers llvm.dbg.declare intrinsics into appropriate set /// of llvm.dbg.value intrinsics. bool llvm::LowerDbgDeclare(Function &F) { - DIBuilder DIB(*F.getParent()); + DIBuilder DIB(*F.getParent(), /*AllowUnresolved*/ false); SmallVector<DbgDeclareInst *, 4> Dbgs; for (auto &FI : F) for (BasicBlock::iterator BI : FI) @@ -1061,14 +1077,14 @@ bool llvm::LowerDbgDeclare(Function &F) { else if (LoadInst *LI = dyn_cast<LoadInst>(U)) ConvertDebugDeclareToDebugValue(DDI, LI, DIB); else if (CallInst *CI = dyn_cast<CallInst>(U)) { - // This is a call by-value or some other instruction that - // takes a pointer to the variable. Insert a *value* - // intrinsic that describes the alloca. - auto DbgVal = - DIB.insertDbgValueIntrinsic(AI, 0, - DIVariable(DDI->getVariable()), CI); - DbgVal->setDebugLoc(DDI->getDebugLoc()); - } + // This is a call by-value or some other instruction that + // takes a pointer to the variable. Insert a *value* + // intrinsic that describes the alloca. + auto DbgVal = DIB.insertDbgValueIntrinsic( + AI, 0, DIVariable(DDI->getVariable()), + DIExpression(DDI->getExpression()), CI); + DbgVal->setDebugLoc(DDI->getDebugLoc()); + } DDI->eraseFromParent(); } } @@ -1078,10 +1094,11 @@ bool llvm::LowerDbgDeclare(Function &F) { /// FindAllocaDbgDeclare - Finds the llvm.dbg.declare intrinsic describing the /// alloca 'V', if any. DbgDeclareInst *llvm::FindAllocaDbgDeclare(Value *V) { - if (MDNode *DebugNode = MDNode::getIfExists(V->getContext(), V)) - for (User *U : DebugNode->users()) - if (DbgDeclareInst *DDI = dyn_cast<DbgDeclareInst>(U)) - return DDI; + if (auto *L = LocalAsMetadata::getIfExists(V)) + if (auto *MDV = MetadataAsValue::getIfExists(V->getContext(), L)) + for (User *U : MDV->users()) + if (DbgDeclareInst *DDI = dyn_cast<DbgDeclareInst>(U)) + return DDI; return nullptr; } @@ -1092,33 +1109,27 @@ bool llvm::replaceDbgDeclareForAlloca(AllocaInst *AI, Value *NewAllocaAddress, if (!DDI) return false; DIVariable DIVar(DDI->getVariable()); + DIExpression DIExpr(DDI->getExpression()); assert((!DIVar || DIVar.isVariable()) && "Variable in DbgDeclareInst should be either null or a DIVariable."); if (!DIVar) return false; - // Create a copy of the original DIDescriptor for user variable, appending + // Create a copy of the original DIDescriptor for user variable, prepending // "deref" operation to a list of address elements, as new llvm.dbg.declare // will take a value storing address of the memory for variable, not // alloca itself. - Type *Int64Ty = Type::getInt64Ty(AI->getContext()); - SmallVector<Value*, 4> NewDIVarAddress; - if (DIVar.hasComplexAddress()) { - for (unsigned i = 0, n = DIVar.getNumAddrElements(); i < n; ++i) { - NewDIVarAddress.push_back( - ConstantInt::get(Int64Ty, DIVar.getAddrElement(i))); - } - } - NewDIVarAddress.push_back(ConstantInt::get(Int64Ty, DIBuilder::OpDeref)); - DIVariable NewDIVar = Builder.createComplexVariable( - DIVar.getTag(), DIVar.getContext(), DIVar.getName(), - DIVar.getFile(), DIVar.getLineNumber(), DIVar.getType(), - NewDIVarAddress, DIVar.getArgNumber()); + SmallVector<int64_t, 4> NewDIExpr; + NewDIExpr.push_back(dwarf::DW_OP_deref); + if (DIExpr) + for (unsigned i = 0, n = DIExpr.getNumElements(); i < n; ++i) + NewDIExpr.push_back(DIExpr.getElement(i)); // Insert llvm.dbg.declare in the same basic block as the original alloca, // and remove old llvm.dbg.declare. BasicBlock *BB = AI->getParent(); - Builder.insertDeclare(NewAllocaAddress, NewDIVar, BB); + Builder.insertDeclare(NewAllocaAddress, DIVar, + Builder.createExpression(NewDIExpr), BB); DDI->eraseFromParent(); return true; } @@ -1170,7 +1181,7 @@ static void changeToCall(InvokeInst *II) { } static bool markAliveBlocks(BasicBlock *BB, - SmallPtrSet<BasicBlock*, 128> &Reachable) { + SmallPtrSetImpl<BasicBlock*> &Reachable) { SmallVector<BasicBlock*, 128> Worklist; Worklist.push_back(BB); @@ -1183,6 +1194,26 @@ static bool markAliveBlocks(BasicBlock *BB, // instructions into LLVM unreachable insts. The instruction combining pass // canonicalizes unreachable insts into stores to null or undef. for (BasicBlock::iterator BBI = BB->begin(), E = BB->end(); BBI != E;++BBI){ + // Assumptions that are known to be false are equivalent to unreachable. + // Also, if the condition is undefined, then we make the choice most + // beneficial to the optimizer, and choose that to also be unreachable. + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(BBI)) + if (II->getIntrinsicID() == Intrinsic::assume) { + bool MakeUnreachable = false; + if (isa<UndefValue>(II->getArgOperand(0))) + MakeUnreachable = true; + else if (ConstantInt *Cond = + dyn_cast<ConstantInt>(II->getArgOperand(0))) + MakeUnreachable = Cond->isZero(); + + if (MakeUnreachable) { + // Don't insert a call to llvm.trap right before the unreachable. + changeToUnreachable(BBI, false); + Changed = true; + break; + } + } + if (CallInst *CI = dyn_cast<CallInst>(BBI)) { if (CI->doesNotReturn()) { // If we found a call to a no-return function, insert an unreachable @@ -1237,7 +1268,7 @@ static bool markAliveBlocks(BasicBlock *BB, Changed |= ConstantFoldTerminator(BB, true); for (succ_iterator SI = succ_begin(BB), SE = succ_end(BB); SI != SE; ++SI) - if (Reachable.insert(*SI)) + if (Reachable.insert(*SI).second) Worklist.push_back(*SI); } while (!Worklist.empty()); return Changed; @@ -1277,3 +1308,43 @@ bool llvm::removeUnreachableBlocks(Function &F) { return true; } + +void llvm::combineMetadata(Instruction *K, const Instruction *J, ArrayRef<unsigned> KnownIDs) { + SmallVector<std::pair<unsigned, MDNode *>, 4> Metadata; + K->dropUnknownMetadata(KnownIDs); + K->getAllMetadataOtherThanDebugLoc(Metadata); + for (unsigned i = 0, n = Metadata.size(); i < n; ++i) { + unsigned Kind = Metadata[i].first; + MDNode *JMD = J->getMetadata(Kind); + MDNode *KMD = Metadata[i].second; + + switch (Kind) { + default: + K->setMetadata(Kind, nullptr); // Remove unknown metadata + break; + case LLVMContext::MD_dbg: + llvm_unreachable("getAllMetadataOtherThanDebugLoc returned a MD_dbg"); + case LLVMContext::MD_tbaa: + K->setMetadata(Kind, MDNode::getMostGenericTBAA(JMD, KMD)); + break; + case LLVMContext::MD_alias_scope: + case LLVMContext::MD_noalias: + K->setMetadata(Kind, MDNode::intersect(JMD, KMD)); + break; + case LLVMContext::MD_range: + K->setMetadata(Kind, MDNode::getMostGenericRange(JMD, KMD)); + break; + case LLVMContext::MD_fpmath: + K->setMetadata(Kind, MDNode::getMostGenericFPMath(JMD, KMD)); + break; + case LLVMContext::MD_invariant_load: + // Only set the !invariant.load if it is present in both instructions. + K->setMetadata(Kind, JMD); + break; + case LLVMContext::MD_nonnull: + // Only set the !nonnull if it is present in both instructions. + K->setMetadata(Kind, JMD); + break; + } + } +} diff --git a/lib/Transforms/Utils/LoopSimplify.cpp b/lib/Transforms/Utils/LoopSimplify.cpp index ef422914b6b2..c832a4b36f50 100644 --- a/lib/Transforms/Utils/LoopSimplify.cpp +++ b/lib/Transforms/Utils/LoopSimplify.cpp @@ -44,6 +44,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/DependenceAnalysis.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" @@ -173,8 +174,7 @@ static BasicBlock *rewriteLoopExitBlock(Loop *L, BasicBlock *Exit, Pass *PP) { if (Exit->isLandingPad()) { SmallVector<BasicBlock*, 2> NewBBs; - SplitLandingPadPredecessors(Exit, ArrayRef<BasicBlock*>(&LoopBlocks[0], - LoopBlocks.size()), + SplitLandingPadPredecessors(Exit, LoopBlocks, ".loopexit", ".nonloopexit", PP, NewBBs); NewExitBB = NewBBs[0]; @@ -209,11 +209,12 @@ static void addBlockAndPredsToSet(BasicBlock *InputBB, BasicBlock *StopBlock, /// \brief The first part of loop-nestification is to find a PHI node that tells /// us how to partition the loops. static PHINode *findPHIToPartitionLoops(Loop *L, AliasAnalysis *AA, - DominatorTree *DT) { + DominatorTree *DT, + AssumptionCache *AC) { for (BasicBlock::iterator I = L->getHeader()->begin(); isa<PHINode>(I); ) { PHINode *PN = cast<PHINode>(I); ++I; - if (Value *V = SimplifyInstruction(PN, nullptr, nullptr, DT)) { + if (Value *V = SimplifyInstruction(PN, nullptr, nullptr, DT, AC)) { // This is a degenerate PHI already, don't modify it! PN->replaceAllUsesWith(V); if (AA) AA->deleteValue(PN); @@ -252,7 +253,8 @@ static PHINode *findPHIToPartitionLoops(Loop *L, AliasAnalysis *AA, /// static Loop *separateNestedLoop(Loop *L, BasicBlock *Preheader, AliasAnalysis *AA, DominatorTree *DT, - LoopInfo *LI, ScalarEvolution *SE, Pass *PP) { + LoopInfo *LI, ScalarEvolution *SE, Pass *PP, + AssumptionCache *AC) { // Don't try to separate loops without a preheader. if (!Preheader) return nullptr; @@ -261,7 +263,7 @@ static Loop *separateNestedLoop(Loop *L, BasicBlock *Preheader, assert(!L->getHeader()->isLandingPad() && "Can't insert backedge to landing pad"); - PHINode *PN = findPHIToPartitionLoops(L, AA, DT); + PHINode *PN = findPHIToPartitionLoops(L, AA, DT, AC); if (!PN) return nullptr; // No known way to partition. // Pull out all predecessors that have varying values in the loop. This @@ -474,8 +476,8 @@ static BasicBlock *insertUniqueBackedgeBlock(Loop *L, BasicBlock *Preheader, /// explicit if they accepted the analysis directly and then updated it. static bool simplifyOneLoop(Loop *L, SmallVectorImpl<Loop *> &Worklist, AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI, - ScalarEvolution *SE, Pass *PP, - const DataLayout *DL) { + ScalarEvolution *SE, Pass *PP, const DataLayout *DL, + AssumptionCache *AC) { bool Changed = false; ReprocessLoop: @@ -496,20 +498,19 @@ ReprocessLoop: } // Delete each unique out-of-loop (and thus dead) predecessor. - for (SmallPtrSet<BasicBlock*, 4>::iterator I = BadPreds.begin(), - E = BadPreds.end(); I != E; ++I) { + for (BasicBlock *P : BadPreds) { DEBUG(dbgs() << "LoopSimplify: Deleting edge from dead predecessor " - << (*I)->getName() << "\n"); + << P->getName() << "\n"); // Inform each successor of each dead pred. - for (succ_iterator SI = succ_begin(*I), SE = succ_end(*I); SI != SE; ++SI) - (*SI)->removePredecessor(*I); + for (succ_iterator SI = succ_begin(P), SE = succ_end(P); SI != SE; ++SI) + (*SI)->removePredecessor(P); // Zap the dead pred's terminator and replace it with unreachable. - TerminatorInst *TI = (*I)->getTerminator(); + TerminatorInst *TI = P->getTerminator(); TI->replaceAllUsesWith(UndefValue::get(TI->getType())); - (*I)->getTerminator()->eraseFromParent(); - new UnreachableInst((*I)->getContext(), *I); + P->getTerminator()->eraseFromParent(); + new UnreachableInst(P->getContext(), P); Changed = true; } } @@ -582,7 +583,8 @@ ReprocessLoop: // this for loops with a giant number of backedges, just factor them into a // common backedge instead. if (L->getNumBackEdges() < 8) { - if (Loop *OuterL = separateNestedLoop(L, Preheader, AA, DT, LI, SE, PP)) { + if (Loop *OuterL = + separateNestedLoop(L, Preheader, AA, DT, LI, SE, PP, AC)) { ++NumNested; // Enqueue the outer loop as it should be processed next in our // depth-first nest walk. @@ -612,7 +614,7 @@ ReprocessLoop: PHINode *PN; for (BasicBlock::iterator I = L->getHeader()->begin(); (PN = dyn_cast<PHINode>(I++)); ) - if (Value *V = SimplifyInstruction(PN, nullptr, nullptr, DT)) { + if (Value *V = SimplifyInstruction(PN, nullptr, nullptr, DT, AC)) { if (AA) AA->deleteValue(PN); if (SE) SE->forgetValue(PN); PN->replaceAllUsesWith(V); @@ -712,7 +714,7 @@ ReprocessLoop: bool llvm::simplifyLoop(Loop *L, DominatorTree *DT, LoopInfo *LI, Pass *PP, AliasAnalysis *AA, ScalarEvolution *SE, - const DataLayout *DL) { + const DataLayout *DL, AssumptionCache *AC) { bool Changed = false; // Worklist maintains our depth-first queue of loops in this nest to process. @@ -730,7 +732,7 @@ bool llvm::simplifyLoop(Loop *L, DominatorTree *DT, LoopInfo *LI, Pass *PP, while (!Worklist.empty()) Changed |= simplifyOneLoop(Worklist.pop_back_val(), Worklist, AA, DT, LI, - SE, PP, DL); + SE, PP, DL, AC); return Changed; } @@ -749,10 +751,13 @@ namespace { LoopInfo *LI; ScalarEvolution *SE; const DataLayout *DL; + AssumptionCache *AC; bool runOnFunction(Function &F) override; void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); + // We need loop information to identify the loops... AU.addRequired<DominatorTreeWrapperPass>(); AU.addPreserved<DominatorTreeWrapperPass>(); @@ -773,11 +778,12 @@ namespace { char LoopSimplify::ID = 0; INITIALIZE_PASS_BEGIN(LoopSimplify, "loop-simplify", - "Canonicalize natural loops", true, false) + "Canonicalize natural loops", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopInfo) INITIALIZE_PASS_END(LoopSimplify, "loop-simplify", - "Canonicalize natural loops", true, false) + "Canonicalize natural loops", false, false) // Publicly exposed interface to pass... char &llvm::LoopSimplifyID = LoopSimplify::ID; @@ -794,10 +800,11 @@ bool LoopSimplify::runOnFunction(Function &F) { SE = getAnalysisIfAvailable<ScalarEvolution>(); DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>(); DL = DLP ? &DLP->getDataLayout() : nullptr; + AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); // Simplify each loop nest in the function. for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I) - Changed |= simplifyLoop(*I, DT, LI, this, AA, SE, DL); + Changed |= simplifyLoop(*I, DT, LI, this, AA, SE, DL, AC); return Changed; } diff --git a/lib/Transforms/Utils/LoopUnroll.cpp b/lib/Transforms/Utils/LoopUnroll.cpp index ab1c25a75e26..57459206e528 100644 --- a/lib/Transforms/Utils/LoopUnroll.cpp +++ b/lib/Transforms/Utils/LoopUnroll.cpp @@ -19,6 +19,7 @@ #include "llvm/Transforms/Utils/UnrollLoop.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopIterator.h" #include "llvm/Analysis/LoopPass.h" @@ -111,7 +112,7 @@ FoldBlockIntoPredecessor(BasicBlock *BB, LoopInfo* LI, LPPassManager *LPM, if (LPM) { if (ScalarEvolution *SE = LPM->getAnalysisIfAvailable<ScalarEvolution>()) { if (Loop *L = LI->getLoopFor(BB)) { - if (ForgottenLoops.insert(L)) + if (ForgottenLoops.insert(L).second) SE->forgetLoop(L); } } @@ -153,8 +154,8 @@ FoldBlockIntoPredecessor(BasicBlock *BB, LoopInfo* LI, LPPassManager *LPM, /// This utility preserves LoopInfo. If DominatorTree or ScalarEvolution are /// available from the Pass it must also preserve those analyses. bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, - bool AllowRuntime, unsigned TripMultiple, - LoopInfo *LI, Pass *PP, LPPassManager *LPM) { + bool AllowRuntime, unsigned TripMultiple, LoopInfo *LI, + Pass *PP, LPPassManager *LPM, AssumptionCache *AC) { BasicBlock *Preheader = L->getLoopPreheader(); if (!Preheader) { DEBUG(dbgs() << " Can't unroll; loop preheader-insertion failed.\n"); @@ -222,11 +223,10 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, // Notify ScalarEvolution that the loop will be substantially changed, // if not outright eliminated. - if (PP) { - ScalarEvolution *SE = PP->getAnalysisIfAvailable<ScalarEvolution>(); - if (SE) - SE->forgetLoop(L); - } + ScalarEvolution *SE = + PP ? PP->getAnalysisIfAvailable<ScalarEvolution>() : nullptr; + if (SE) + SE->forgetLoop(L); // If we know the trip count, we know the multiple... unsigned BreakoutTrip = 0; @@ -300,15 +300,45 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, for (unsigned It = 1; It != Count; ++It) { std::vector<BasicBlock*> NewBlocks; + SmallDenseMap<const Loop *, Loop *, 4> NewLoops; + NewLoops[L] = L; for (LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) { ValueToValueMapTy VMap; BasicBlock *New = CloneBasicBlock(*BB, VMap, "." + Twine(It)); Header->getParent()->getBasicBlockList().push_back(New); - // Loop over all of the PHI nodes in the block, changing them to use the - // incoming values from the previous block. + // Tell LI about New. + if (*BB == Header) { + assert(LI->getLoopFor(*BB) == L && "Header should not be in a sub-loop"); + L->addBasicBlockToLoop(New, LI->getBase()); + } else { + // Figure out which loop New is in. + const Loop *OldLoop = LI->getLoopFor(*BB); + assert(OldLoop && "Should (at least) be in the loop being unrolled!"); + + Loop *&NewLoop = NewLoops[OldLoop]; + if (!NewLoop) { + // Found a new sub-loop. + assert(*BB == OldLoop->getHeader() && + "Header should be first in RPO"); + + Loop *NewLoopParent = NewLoops.lookup(OldLoop->getParentLoop()); + assert(NewLoopParent && + "Expected parent loop before sub-loop in RPO"); + NewLoop = new Loop; + NewLoopParent->addChildLoop(NewLoop); + + // Forget the old loop, since its inputs may have changed. + if (SE) + SE->forgetLoop(OldLoop); + } + NewLoop->addBasicBlockToLoop(New, LI->getBase()); + } + if (*BB == Header) + // Loop over all of the PHI nodes in the block, changing them to use + // the incoming values from the previous block. for (unsigned i = 0, e = OrigPHINode.size(); i != e; ++i) { PHINode *NewPHI = cast<PHINode>(VMap[OrigPHINode[i]]); Value *InVal = NewPHI->getIncomingValueForBlock(LatchBlock); @@ -325,8 +355,6 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, VI != VE; ++VI) LastValueMap[VI->first] = VI->second; - L->addBasicBlockToLoop(New, LI->getBase()); - // Add phi entries for newly created values to all exit blocks. for (succ_iterator SI = succ_begin(*BB), SE = succ_end(*BB); SI != SE; ++SI) { @@ -442,6 +470,10 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, } } + // FIXME: We could register any cloned assumptions instead of clearing the + // whole function's cache. + AC->clear(); + DominatorTree *DT = nullptr; if (PP) { // FIXME: Reconstruct dom info, because it is not preserved properly. @@ -453,7 +485,6 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, } // Simplify any new induction variables in the partially unrolled loop. - ScalarEvolution *SE = PP->getAnalysisIfAvailable<ScalarEvolution>(); if (SE && !CompletelyUnroll) { SmallVector<WeakVH, 16> DeadInsts; simplifyLoopIVs(L, SE, LPM, DeadInsts); @@ -502,8 +533,7 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, if (OuterL) { DataLayoutPass *DLP = PP->getAnalysisIfAvailable<DataLayoutPass>(); const DataLayout *DL = DLP ? &DLP->getDataLayout() : nullptr; - ScalarEvolution *SE = PP->getAnalysisIfAvailable<ScalarEvolution>(); - simplifyLoop(OuterL, DT, LI, PP, /*AliasAnalysis*/ nullptr, SE, DL); + simplifyLoop(OuterL, DT, LI, PP, /*AliasAnalysis*/ nullptr, SE, DL, AC); // LCSSA must be performed on the outermost affected loop. The unrolled // loop's last loop latch is guaranteed to be in the outermost loop after @@ -513,7 +543,7 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, while (OuterL->getParentLoop() != LatchLoop) OuterL = OuterL->getParentLoop(); - formLCSSARecursively(*OuterL, *DT, SE); + formLCSSARecursively(*OuterL, *DT, LI, SE); } } diff --git a/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/lib/Transforms/Utils/LoopUnrollRuntime.cpp index a96c46ad63e0..f12cd61d463a 100644 --- a/lib/Transforms/Utils/LoopUnrollRuntime.cpp +++ b/lib/Transforms/Utils/LoopUnrollRuntime.cpp @@ -28,6 +28,7 @@ #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Metadata.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" @@ -57,7 +58,7 @@ STATISTIC(NumRuntimeUnrolled, static void ConnectProlog(Loop *L, Value *TripCount, unsigned Count, BasicBlock *LastPrologBB, BasicBlock *PrologEnd, BasicBlock *OrigPH, BasicBlock *NewPH, - ValueToValueMapTy &LVMap, Pass *P) { + ValueToValueMapTy &VMap, Pass *P) { BasicBlock *Latch = L->getLoopLatch(); assert(Latch && "Loop must have a latch"); @@ -86,7 +87,7 @@ static void ConnectProlog(Loop *L, Value *TripCount, unsigned Count, Value *V = PN->getIncomingValueForBlock(Latch); if (Instruction *I = dyn_cast<Instruction>(V)) { if (L->contains(I)) { - V = LVMap[I]; + V = VMap[I]; } } // Adding a value to the new PHI node from the last prolog block @@ -127,76 +128,123 @@ static void ConnectProlog(Loop *L, Value *TripCount, unsigned Count, } /// Create a clone of the blocks in a loop and connect them together. -/// This function doesn't create a clone of the loop structure. +/// If UnrollProlog is true, loop structure will not be cloned, otherwise a new +/// loop will be created including all cloned blocks, and the iterator of it +/// switches to count NewIter down to 0. /// -/// There are two value maps that are defined and used. VMap is -/// for the values in the current loop instance. LVMap contains -/// the values from the last loop instance. We need the LVMap values -/// to update the initial values for the current loop instance. -/// -static void CloneLoopBlocks(Loop *L, - bool FirstCopy, - BasicBlock *InsertTop, - BasicBlock *InsertBot, +static void CloneLoopBlocks(Loop *L, Value *NewIter, const bool UnrollProlog, + BasicBlock *InsertTop, BasicBlock *InsertBot, std::vector<BasicBlock *> &NewBlocks, - LoopBlocksDFS &LoopBlocks, - ValueToValueMapTy &VMap, - ValueToValueMapTy &LVMap, + LoopBlocksDFS &LoopBlocks, ValueToValueMapTy &VMap, LoopInfo *LI) { - BasicBlock *Preheader = L->getLoopPreheader(); BasicBlock *Header = L->getHeader(); BasicBlock *Latch = L->getLoopLatch(); Function *F = Header->getParent(); LoopBlocksDFS::RPOIterator BlockBegin = LoopBlocks.beginRPO(); LoopBlocksDFS::RPOIterator BlockEnd = LoopBlocks.endRPO(); + Loop *NewLoop = 0; + Loop *ParentLoop = L->getParentLoop(); + if (!UnrollProlog) { + NewLoop = new Loop(); + if (ParentLoop) + ParentLoop->addChildLoop(NewLoop); + else + LI->addTopLevelLoop(NewLoop); + } + // For each block in the original loop, create a new copy, // and update the value map with the newly created values. for (LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) { - BasicBlock *NewBB = CloneBasicBlock(*BB, VMap, ".unr", F); + BasicBlock *NewBB = CloneBasicBlock(*BB, VMap, ".prol", F); NewBlocks.push_back(NewBB); - if (Loop *ParentLoop = L->getParentLoop()) + if (NewLoop) + NewLoop->addBasicBlockToLoop(NewBB, LI->getBase()); + else if (ParentLoop) ParentLoop->addBasicBlockToLoop(NewBB, LI->getBase()); VMap[*BB] = NewBB; if (Header == *BB) { // For the first block, add a CFG connection to this newly - // created block + // created block. InsertTop->getTerminator()->setSuccessor(0, NewBB); - // Change the incoming values to the ones defined in the - // previously cloned loop. - for (BasicBlock::iterator I = Header->begin(); isa<PHINode>(I); ++I) { - PHINode *NewPHI = cast<PHINode>(VMap[I]); - if (FirstCopy) { - // We replace the first phi node with the value from the preheader - VMap[I] = NewPHI->getIncomingValueForBlock(Preheader); - NewBB->getInstList().erase(NewPHI); - } else { - // Update VMap with values from the previous block - unsigned idx = NewPHI->getBasicBlockIndex(Latch); - Value *InVal = NewPHI->getIncomingValue(idx); - if (Instruction *I = dyn_cast<Instruction>(InVal)) - if (L->contains(I)) - InVal = LVMap[InVal]; - NewPHI->setIncomingValue(idx, InVal); - NewPHI->setIncomingBlock(idx, InsertTop); - } - } } - if (Latch == *BB) { + // For the last block, if UnrollProlog is true, create a direct jump to + // InsertBot. If not, create a loop back to cloned head. VMap.erase((*BB)->getTerminator()); - NewBB->getTerminator()->eraseFromParent(); - BranchInst::Create(InsertBot, NewBB); + BasicBlock *FirstLoopBB = cast<BasicBlock>(VMap[Header]); + BranchInst *LatchBR = cast<BranchInst>(NewBB->getTerminator()); + if (UnrollProlog) { + LatchBR->eraseFromParent(); + BranchInst::Create(InsertBot, NewBB); + } else { + PHINode *NewIdx = PHINode::Create(NewIter->getType(), 2, "prol.iter", + FirstLoopBB->getFirstNonPHI()); + IRBuilder<> Builder(LatchBR); + Value *IdxSub = + Builder.CreateSub(NewIdx, ConstantInt::get(NewIdx->getType(), 1), + NewIdx->getName() + ".sub"); + Value *IdxCmp = + Builder.CreateIsNotNull(IdxSub, NewIdx->getName() + ".cmp"); + BranchInst::Create(FirstLoopBB, InsertBot, IdxCmp, NewBB); + NewIdx->addIncoming(NewIter, InsertTop); + NewIdx->addIncoming(IdxSub, NewBB); + LatchBR->eraseFromParent(); + } } } - // LastValueMap is updated with the values for the current loop - // which are used the next time this function is called. - for (ValueToValueMapTy::iterator VI = VMap.begin(), VE = VMap.end(); - VI != VE; ++VI) { - LVMap[VI->first] = VI->second; + + // Change the incoming values to the ones defined in the preheader or + // cloned loop. + for (BasicBlock::iterator I = Header->begin(); isa<PHINode>(I); ++I) { + PHINode *NewPHI = cast<PHINode>(VMap[I]); + if (UnrollProlog) { + VMap[I] = NewPHI->getIncomingValueForBlock(Preheader); + cast<BasicBlock>(VMap[Header])->getInstList().erase(NewPHI); + } else { + unsigned idx = NewPHI->getBasicBlockIndex(Preheader); + NewPHI->setIncomingBlock(idx, InsertTop); + BasicBlock *NewLatch = cast<BasicBlock>(VMap[Latch]); + idx = NewPHI->getBasicBlockIndex(Latch); + Value *InVal = NewPHI->getIncomingValue(idx); + NewPHI->setIncomingBlock(idx, NewLatch); + if (VMap[InVal]) + NewPHI->setIncomingValue(idx, VMap[InVal]); + } + } + if (NewLoop) { + // Add unroll disable metadata to disable future unrolling for this loop. + SmallVector<Metadata *, 4> MDs; + // Reserve first location for self reference to the LoopID metadata node. + MDs.push_back(nullptr); + MDNode *LoopID = NewLoop->getLoopID(); + if (LoopID) { + // First remove any existing loop unrolling metadata. + for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) { + bool IsUnrollMetadata = false; + MDNode *MD = dyn_cast<MDNode>(LoopID->getOperand(i)); + if (MD) { + const MDString *S = dyn_cast<MDString>(MD->getOperand(0)); + IsUnrollMetadata = S && S->getString().startswith("llvm.loop.unroll."); + } + if (!IsUnrollMetadata) + MDs.push_back(LoopID->getOperand(i)); + } + } + + LLVMContext &Context = NewLoop->getHeader()->getContext(); + SmallVector<Metadata *, 1> DisableOperands; + DisableOperands.push_back(MDString::get(Context, "llvm.loop.unroll.disable")); + MDNode *DisableNode = MDNode::get(Context, DisableOperands); + MDs.push_back(DisableNode); + + MDNode *NewLoopID = MDNode::get(Context, MDs); + // Set operand 0 to refer to the loop id itself. + NewLoopID->replaceOperandWith(0, NewLoopID); + NewLoop->setLoopID(NewLoopID); } } @@ -212,18 +260,16 @@ static void CloneLoopBlocks(Loop *L, /// instruction in SimplifyCFG.cpp. Then, the backend decides how code for /// the switch instruction is generated. /// -/// extraiters = tripcount % loopfactor -/// if (extraiters == 0) jump Loop: -/// if (extraiters == loopfactor) jump L1 -/// if (extraiters == loopfactor-1) jump L2 -/// ... -/// L1: LoopBody; -/// L2: LoopBody; -/// ... -/// if tripcount < loopfactor jump End -/// Loop: -/// ... -/// End: +/// extraiters = tripcount % loopfactor +/// if (extraiters == 0) jump Loop: +/// else jump Prol +/// Prol: LoopBody; +/// extraiters -= 1 // Omitted if unroll factor is 2. +/// if (extraiters != 0) jump Prol: // Omitted if unroll factor is 2. +/// if (tripcount < loopfactor) jump End +/// Loop: +/// ... +/// End: /// bool llvm::UnrollRuntimeLoopProlog(Loop *L, unsigned Count, LoopInfo *LI, LPPassManager *LPM) { @@ -250,6 +296,10 @@ bool llvm::UnrollRuntimeLoopProlog(Loop *L, unsigned Count, LoopInfo *LI, if (isa<SCEVCouldNotCompute>(BECount) || !BECount->getType()->isIntegerTy()) return false; + // If BECount is INT_MAX, we can't compute trip-count without overflow. + if (BECount->isAllOnesValue()) + return false; + // Add 1 since the backedge count doesn't include the first loop iteration const SCEV *TripCountSC = SE->getAddExpr(BECount, SE->getConstant(BECount->getType(), 1)); @@ -284,26 +334,21 @@ bool llvm::UnrollRuntimeLoopProlog(Loop *L, unsigned Count, LoopInfo *LI, IRBuilder<> B(PreHeaderBR); Value *ModVal = B.CreateAnd(TripCount, Count - 1, "xtraiter"); - // Check if for no extra iterations, then jump to unrolled loop. We have to - // check that the trip count computation didn't overflow when adding one to - // the backedge taken count. + // Check if for no extra iterations, then jump to cloned/unrolled loop. + // We have to check that the trip count computation didn't overflow when + // adding one to the backedge taken count. Value *LCmp = B.CreateIsNotNull(ModVal, "lcmp.mod"); Value *OverflowCheck = B.CreateIsNull(TripCount, "lcmp.overflow"); Value *BranchVal = B.CreateOr(OverflowCheck, LCmp, "lcmp.or"); - // Branch to either the extra iterations or the unrolled loop + // Branch to either the extra iterations or the cloned/unrolled loop // We will fix up the true branch label when adding loop body copies BranchInst::Create(PEnd, PEnd, BranchVal, PreHeaderBR); assert(PreHeaderBR->isUnconditional() && PreHeaderBR->getSuccessor(0) == PEnd && "CFG edges in Preheader are not correct"); PreHeaderBR->eraseFromParent(); - - ValueToValueMapTy LVMap; Function *F = Header->getParent(); - // These variables are used to update the CFG links in each iteration - BasicBlock *CompareBB = nullptr; - BasicBlock *LastLoopBB = PH; // Get an ordered list of blocks in the loop to help with the ordering of the // cloned blocks in the prolog code LoopBlocksDFS LoopBlocks(L); @@ -314,62 +359,39 @@ bool llvm::UnrollRuntimeLoopProlog(Loop *L, unsigned Count, LoopInfo *LI, // and generate a condition that branches to the copy depending on the // number of 'left over' iterations. // - for (unsigned leftOverIters = Count-1; leftOverIters > 0; --leftOverIters) { - std::vector<BasicBlock*> NewBlocks; - ValueToValueMapTy VMap; - - // Clone all the basic blocks in the loop, but we don't clone the loop - // This function adds the appropriate CFG connections. - CloneLoopBlocks(L, (leftOverIters == Count-1), LastLoopBB, PEnd, NewBlocks, - LoopBlocks, VMap, LVMap, LI); - LastLoopBB = cast<BasicBlock>(VMap[Latch]); - - // Insert the cloned blocks into function just before the original loop - F->getBasicBlockList().splice(PEnd, F->getBasicBlockList(), - NewBlocks[0], F->end()); - - // Generate the code for the comparison which determines if the loop - // prolog code needs to be executed. - if (leftOverIters == Count-1) { - // There is no compare block for the fall-thru case when for the last - // left over iteration - CompareBB = NewBlocks[0]; - } else { - // Create a new block for the comparison - BasicBlock *NewBB = BasicBlock::Create(CompareBB->getContext(), "unr.cmp", - F, CompareBB); - if (Loop *ParentLoop = L->getParentLoop()) { - // Add the new block to the parent loop, if needed - ParentLoop->addBasicBlockToLoop(NewBB, LI->getBase()); - } - - // The comparison w/ the extra iteration value and branch - Type *CountTy = TripCount->getType(); - Value *BranchVal = new ICmpInst(*NewBB, ICmpInst::ICMP_EQ, ModVal, - ConstantInt::get(CountTy, leftOverIters), - "un.tmp"); - // Branch to either the extra iterations or the unrolled loop - BranchInst::Create(NewBlocks[0], CompareBB, - BranchVal, NewBB); - CompareBB = NewBB; - PH->getTerminator()->setSuccessor(0, NewBB); - VMap[NewPH] = CompareBB; - } - - // Rewrite the cloned instruction operands to use the values - // created when the clone is created. - for (unsigned i = 0, e = NewBlocks.size(); i != e; ++i) { - for (BasicBlock::iterator I = NewBlocks[i]->begin(), - E = NewBlocks[i]->end(); I != E; ++I) { - RemapInstruction(I, VMap, - RF_NoModuleLevelChanges|RF_IgnoreMissingEntries); - } + std::vector<BasicBlock *> NewBlocks; + ValueToValueMapTy VMap; + + // If unroll count is 2 and we can't overflow in tripcount computation (which + // is BECount + 1), then we don't need a loop for prologue, and we can unroll + // it. We can be sure that we don't overflow only if tripcount is a constant. + bool UnrollPrologue = (Count == 2 && isa<ConstantInt>(TripCount)); + + // Clone all the basic blocks in the loop. If Count is 2, we don't clone + // the loop, otherwise we create a cloned loop to execute the extra + // iterations. This function adds the appropriate CFG connections. + CloneLoopBlocks(L, ModVal, UnrollPrologue, PH, PEnd, NewBlocks, LoopBlocks, + VMap, LI); + + // Insert the cloned blocks into function just before the original loop + F->getBasicBlockList().splice(PEnd, F->getBasicBlockList(), NewBlocks[0], + F->end()); + + // Rewrite the cloned instruction operands to use the values + // created when the clone is created. + for (unsigned i = 0, e = NewBlocks.size(); i != e; ++i) { + for (BasicBlock::iterator I = NewBlocks[i]->begin(), + E = NewBlocks[i]->end(); + I != E; ++I) { + RemapInstruction(I, VMap, + RF_NoModuleLevelChanges | RF_IgnoreMissingEntries); } } // Connect the prolog code to the original loop and update the // PHI functions. - ConnectProlog(L, TripCount, Count, LastLoopBB, PEnd, PH, NewPH, LVMap, + BasicBlock *LastLoopBB = cast<BasicBlock>(VMap[Latch]); + ConnectProlog(L, TripCount, Count, LastLoopBB, PEnd, PH, NewPH, VMap, LPM->getAsPass()); NumRuntimeUnrolled++; return true; diff --git a/lib/Transforms/Utils/LowerSwitch.cpp b/lib/Transforms/Utils/LowerSwitch.cpp index d6e5bb626805..35cd917330ab 100644 --- a/lib/Transforms/Utils/LowerSwitch.cpp +++ b/lib/Transforms/Utils/LowerSwitch.cpp @@ -131,18 +131,39 @@ static raw_ostream& operator<<(raw_ostream &O, return O << "]"; } -static void fixPhis(BasicBlock *Succ, - BasicBlock *OrigBlock, - BasicBlock *NewNode) { - for (BasicBlock::iterator I = Succ->begin(), - E = Succ->getFirstNonPHI(); - I != E; ++I) { +// \brief Update the first occurrence of the "switch statement" BB in the PHI +// node with the "new" BB. The other occurrences will: +// +// 1) Be updated by subsequent calls to this function. Switch statements may +// have more than one outcoming edge into the same BB if they all have the same +// value. When the switch statement is converted these incoming edges are now +// coming from multiple BBs. +// 2) Removed if subsequent incoming values now share the same case, i.e., +// multiple outcome edges are condensed into one. This is necessary to keep the +// number of phi values equal to the number of branches to SuccBB. +static void fixPhis(BasicBlock *SuccBB, BasicBlock *OrigBB, BasicBlock *NewBB, + unsigned NumMergedCases) { + for (BasicBlock::iterator I = SuccBB->begin(), IE = SuccBB->getFirstNonPHI(); + I != IE; ++I) { PHINode *PN = cast<PHINode>(I); - for (unsigned I = 0, E = PN->getNumIncomingValues(); I != E; ++I) { - if (PN->getIncomingBlock(I) == OrigBlock) - PN->setIncomingBlock(I, NewNode); + // Only update the first occurence. + unsigned Idx = 0, E = PN->getNumIncomingValues(); + unsigned LocalNumMergedCases = NumMergedCases; + for (; Idx != E; ++Idx) { + if (PN->getIncomingBlock(Idx) == OrigBB) { + PN->setIncomingBlock(Idx, NewBB); + break; + } } + + // Remove additional occurences coming from condensed cases and keep the + // number of incoming values equal to the number of branches to SuccBB. + for (++Idx; LocalNumMergedCases > 0 && Idx < E; ++Idx) + if (PN->getIncomingBlock(Idx) == OrigBB) { + PN->removeIncomingValue(Idx); + LocalNumMergedCases--; + } } } @@ -165,7 +186,11 @@ BasicBlock *LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, // emitting the code that checks if the value actually falls in the range // because the bounds already tell us so. if (Begin->Low == LowerBound && Begin->High == UpperBound) { - fixPhis(Begin->BB, OrigBlock, Predecessor); + unsigned NumMergedCases = 0; + if (LowerBound && UpperBound) + NumMergedCases = + UpperBound->getSExtValue() - LowerBound->getSExtValue(); + fixPhis(Begin->BB, OrigBlock, Predecessor, NumMergedCases); return Begin->BB; } return newLeafBlock(*Begin, Val, OrigBlock, Default); diff --git a/lib/Transforms/Utils/Mem2Reg.cpp b/lib/Transforms/Utils/Mem2Reg.cpp index 189caa7d145a..00cf4e6c01c8 100644 --- a/lib/Transforms/Utils/Mem2Reg.cpp +++ b/lib/Transforms/Utils/Mem2Reg.cpp @@ -14,6 +14,7 @@ #include "llvm/Transforms/Scalar.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" @@ -38,6 +39,7 @@ namespace { bool runOnFunction(Function &F) override; void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<DominatorTreeWrapperPass>(); AU.setPreservesCFG(); // This is a cluster of orthogonal Transforms @@ -51,6 +53,7 @@ namespace { char PromotePass::ID = 0; INITIALIZE_PASS_BEGIN(PromotePass, "mem2reg", "Promote Memory to Register", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_END(PromotePass, "mem2reg", "Promote Memory to Register", false, false) @@ -63,6 +66,8 @@ bool PromotePass::runOnFunction(Function &F) { bool Changed = false; DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + AssumptionCache &AC = + getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); while (1) { Allocas.clear(); @@ -76,7 +81,7 @@ bool PromotePass::runOnFunction(Function &F) { if (Allocas.empty()) break; - PromoteMemToReg(Allocas, DT); + PromoteMemToReg(Allocas, DT, nullptr, &AC); NumPromoted += Allocas.size(); Changed = true; } diff --git a/lib/Transforms/Utils/ModuleUtils.cpp b/lib/Transforms/Utils/ModuleUtils.cpp index d9dbbca1c366..35c701eeedc9 100644 --- a/lib/Transforms/Utils/ModuleUtils.cpp +++ b/lib/Transforms/Utils/ModuleUtils.cpp @@ -78,7 +78,7 @@ void llvm::appendToGlobalDtors(Module &M, Function *F, int Priority) { } GlobalVariable * -llvm::collectUsedGlobalVariables(Module &M, SmallPtrSet<GlobalValue *, 8> &Set, +llvm::collectUsedGlobalVariables(Module &M, SmallPtrSetImpl<GlobalValue *> &Set, bool CompilerUsed) { const char *Name = CompilerUsed ? "llvm.compiler.used" : "llvm.used"; GlobalVariable *GV = M.getGlobalVariable(Name); diff --git a/lib/Transforms/Utils/PromoteMemoryToRegister.cpp b/lib/Transforms/Utils/PromoteMemoryToRegister.cpp index 06d73feb1cc8..dabadb794d4e 100644 --- a/lib/Transforms/Utils/PromoteMemoryToRegister.cpp +++ b/lib/Transforms/Utils/PromoteMemoryToRegister.cpp @@ -238,6 +238,9 @@ struct PromoteMem2Reg { /// An AliasSetTracker object to update. If null, don't update it. AliasSetTracker *AST; + /// A cache of @llvm.assume intrinsics used by SimplifyInstruction. + AssumptionCache *AC; + /// Reverse mapping of Allocas. DenseMap<AllocaInst *, unsigned> AllocaLookup; @@ -279,9 +282,10 @@ struct PromoteMem2Reg { public: PromoteMem2Reg(ArrayRef<AllocaInst *> Allocas, DominatorTree &DT, - AliasSetTracker *AST) + AliasSetTracker *AST, AssumptionCache *AC) : Allocas(Allocas.begin(), Allocas.end()), DT(DT), - DIB(*DT.getRoot()->getParent()->getParent()), AST(AST) {} + DIB(*DT.getRoot()->getParent()->getParent(), /*AllowUnresolved*/ false), + AST(AST), AC(AC) {} void run(); @@ -302,8 +306,8 @@ private: void DetermineInsertionPoint(AllocaInst *AI, unsigned AllocaNum, AllocaInfo &Info); void ComputeLiveInBlocks(AllocaInst *AI, AllocaInfo &Info, - const SmallPtrSet<BasicBlock *, 32> &DefBlocks, - SmallPtrSet<BasicBlock *, 32> &LiveInBlocks); + const SmallPtrSetImpl<BasicBlock *> &DefBlocks, + SmallPtrSetImpl<BasicBlock *> &LiveInBlocks); void RenamePass(BasicBlock *BB, BasicBlock *Pred, RenamePassData::ValVector &IncVals, std::vector<RenamePassData> &Worklist); @@ -412,7 +416,8 @@ static bool rewriteSingleStoreAlloca(AllocaInst *AI, AllocaInfo &Info, // Record debuginfo for the store and remove the declaration's // debuginfo. if (DbgDeclareInst *DDI = Info.DbgDeclare) { - DIBuilder DIB(*AI->getParent()->getParent()->getParent()); + DIBuilder DIB(*AI->getParent()->getParent()->getParent(), + /*AllowUnresolved*/ false); ConvertDebugDeclareToDebugValue(DDI, Info.OnlyStore, DIB); DDI->eraseFromParent(); LBI.deleteValue(DDI); @@ -495,7 +500,8 @@ static void promoteSingleBlockAlloca(AllocaInst *AI, const AllocaInfo &Info, StoreInst *SI = cast<StoreInst>(AI->user_back()); // Record debuginfo for the store before removing it. if (DbgDeclareInst *DDI = Info.DbgDeclare) { - DIBuilder DIB(*AI->getParent()->getParent()->getParent()); + DIBuilder DIB(*AI->getParent()->getParent()->getParent(), + /*AllowUnresolved*/ false); ConvertDebugDeclareToDebugValue(DDI, SI, DIB); } SI->eraseFromParent(); @@ -685,7 +691,7 @@ void PromoteMem2Reg::run() { PHINode *PN = I->second; // If this PHI node merges one value and/or undefs, get the value. - if (Value *V = SimplifyInstruction(PN, nullptr, nullptr, &DT)) { + if (Value *V = SimplifyInstruction(PN, nullptr, nullptr, &DT, AC)) { if (AST && PN->getType()->isPointerTy()) AST->deleteValue(PN); PN->replaceAllUsesWith(V); @@ -766,8 +772,8 @@ void PromoteMem2Reg::run() { /// inserted phi nodes would be dead). void PromoteMem2Reg::ComputeLiveInBlocks( AllocaInst *AI, AllocaInfo &Info, - const SmallPtrSet<BasicBlock *, 32> &DefBlocks, - SmallPtrSet<BasicBlock *, 32> &LiveInBlocks) { + const SmallPtrSetImpl<BasicBlock *> &DefBlocks, + SmallPtrSetImpl<BasicBlock *> &LiveInBlocks) { // To determine liveness, we must iterate through the predecessors of blocks // where the def is live. Blocks are added to the worklist if we need to @@ -816,7 +822,7 @@ void PromoteMem2Reg::ComputeLiveInBlocks( // The block really is live in here, insert it into the set. If already in // the set, then it has already been processed. - if (!LiveInBlocks.insert(BB)) + if (!LiveInBlocks.insert(BB).second) continue; // Since the value is live into BB, it is either defined in a predecessor or @@ -857,10 +863,8 @@ void PromoteMem2Reg::DetermineInsertionPoint(AllocaInst *AI, unsigned AllocaNum, less_second> IDFPriorityQueue; IDFPriorityQueue PQ; - for (SmallPtrSet<BasicBlock *, 32>::const_iterator I = DefBlocks.begin(), - E = DefBlocks.end(); - I != E; ++I) { - if (DomTreeNode *Node = DT.getNode(*I)) + for (BasicBlock *BB : DefBlocks) { + if (DomTreeNode *Node = DT.getNode(BB)) PQ.push(std::make_pair(Node, DomLevels[Node])); } @@ -898,7 +902,7 @@ void PromoteMem2Reg::DetermineInsertionPoint(AllocaInst *AI, unsigned AllocaNum, if (SuccLevel > RootLevel) continue; - if (!Visited.insert(SuccNode)) + if (!Visited.insert(SuccNode).second) continue; BasicBlock *SuccBB = SuccNode->getBlock(); @@ -1003,7 +1007,7 @@ NextIteration: } // Don't revisit blocks. - if (!Visited.insert(BB)) + if (!Visited.insert(BB).second) return; for (BasicBlock::iterator II = BB->begin(); !isa<TerminatorInst>(II);) { @@ -1060,17 +1064,17 @@ NextIteration: ++I; for (; I != E; ++I) - if (VisitedSuccs.insert(*I)) + if (VisitedSuccs.insert(*I).second) Worklist.push_back(RenamePassData(*I, Pred, IncomingVals)); goto NextIteration; } void llvm::PromoteMemToReg(ArrayRef<AllocaInst *> Allocas, DominatorTree &DT, - AliasSetTracker *AST) { + AliasSetTracker *AST, AssumptionCache *AC) { // If there is nothing to do, bail out... if (Allocas.empty()) return; - PromoteMem2Reg(Allocas, DT, AST).run(); + PromoteMem2Reg(Allocas, DT, AST, AC).run(); } diff --git a/lib/Transforms/Utils/SimplifyCFG.cpp b/lib/Transforms/Utils/SimplifyCFG.cpp index 24bb63bb60a5..f6867c24ee58 100644 --- a/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/lib/Transforms/Utils/SimplifyCFG.cpp @@ -43,6 +43,8 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/ValueMapper.h" #include <algorithm> #include <map> #include <set> @@ -68,12 +70,24 @@ static cl::opt<bool> HoistCondStores( cl::desc("Hoist conditional stores if an unconditional store precedes")); STATISTIC(NumBitMaps, "Number of switch instructions turned into bitmaps"); +STATISTIC(NumLinearMaps, "Number of switch instructions turned into linear mapping"); STATISTIC(NumLookupTables, "Number of switch instructions turned into lookup tables"); STATISTIC(NumLookupTablesHoles, "Number of switch instructions turned into lookup tables (holes checked)"); +STATISTIC(NumTableCmpReuses, "Number of reused switch table lookup compares"); STATISTIC(NumSinkCommons, "Number of common instructions sunk down to the end block"); STATISTIC(NumSpeculations, "Number of speculative executed instructions"); namespace { + // The first field contains the value that the switch produces when a certain + // case group is selected, and the second field is a vector containing the cases + // composing the case group. + typedef SmallVector<std::pair<Constant *, SmallVector<ConstantInt *, 4>>, 2> + SwitchCaseResultVectorTy; + // The first field contains the phi node that generates a result of the switch + // and the second field contains the value generated for a certain case in the switch + // for that PHI. + typedef SmallVector<std::pair<PHINode *, Constant *>, 4> SwitchCaseResultsTy; + /// ValueEqualityComparisonCase - Represents a case of a switch. struct ValueEqualityComparisonCase { ConstantInt *Value; @@ -92,7 +106,9 @@ namespace { class SimplifyCFGOpt { const TargetTransformInfo &TTI; + unsigned BonusInstThreshold; const DataLayout *const DL; + AssumptionCache *AC; Value *isValueEqualityComparison(TerminatorInst *TI); BasicBlock *GetValueEqualityComparisonCases(TerminatorInst *TI, std::vector<ValueEqualityComparisonCase> &Cases); @@ -111,8 +127,9 @@ class SimplifyCFGOpt { bool SimplifyCondBranch(BranchInst *BI, IRBuilder <>&Builder); public: - SimplifyCFGOpt(const TargetTransformInfo &TTI, const DataLayout *DL) - : TTI(TTI), DL(DL) {} + SimplifyCFGOpt(const TargetTransformInfo &TTI, unsigned BonusInstThreshold, + const DataLayout *DL, AssumptionCache *AC) + : TTI(TTI), BonusInstThreshold(BonusInstThreshold), DL(DL), AC(AC) {} bool run(BasicBlock *BB); }; } @@ -256,7 +273,7 @@ static unsigned ComputeSpeculationCost(const User *I, const DataLayout *DL) { /// V plus its non-dominating operands. If that cost is greater than /// CostRemaining, false is returned and CostRemaining is undefined. static bool DominatesMergePoint(Value *V, BasicBlock *BB, - SmallPtrSet<Instruction*, 4> *AggressiveInsts, + SmallPtrSetImpl<Instruction*> *AggressiveInsts, unsigned &CostRemaining, const DataLayout *DL) { Instruction *I = dyn_cast<Instruction>(V); @@ -341,114 +358,177 @@ static ConstantInt *GetConstantInt(Value *V, const DataLayout *DL) { return nullptr; } -/// GatherConstantCompares - Given a potentially 'or'd or 'and'd together -/// collection of icmp eq/ne instructions that compare a value against a -/// constant, return the value being compared, and stick the constant into the -/// Values vector. -static Value * -GatherConstantCompares(Value *V, std::vector<ConstantInt*> &Vals, Value *&Extra, - const DataLayout *DL, bool isEQ, unsigned &UsedICmps) { - Instruction *I = dyn_cast<Instruction>(V); - if (!I) return nullptr; - - // If this is an icmp against a constant, handle this as one of the cases. - if (ICmpInst *ICI = dyn_cast<ICmpInst>(I)) { - if (ConstantInt *C = GetConstantInt(I->getOperand(1), DL)) { - Value *RHSVal; - ConstantInt *RHSC; - - if (ICI->getPredicate() == (isEQ ? ICmpInst::ICMP_EQ:ICmpInst::ICMP_NE)) { - // (x & ~2^x) == y --> x == y || x == y|2^x - // This undoes a transformation done by instcombine to fuse 2 compares. - if (match(ICI->getOperand(0), - m_And(m_Value(RHSVal), m_ConstantInt(RHSC)))) { - APInt Not = ~RHSC->getValue(); - if (Not.isPowerOf2()) { - Vals.push_back(C); - Vals.push_back( - ConstantInt::get(C->getContext(), C->getValue() | Not)); - UsedICmps++; - return RHSVal; - } - } +namespace { + +/// Given a chain of or (||) or and (&&) comparison of a value against a +/// constant, this will try to recover the information required for a switch +/// structure. +/// It will depth-first traverse the chain of comparison, seeking for patterns +/// like %a == 12 or %a < 4 and combine them to produce a set of integer +/// representing the different cases for the switch. +/// Note that if the chain is composed of '||' it will build the set of elements +/// that matches the comparisons (i.e. any of this value validate the chain) +/// while for a chain of '&&' it will build the set elements that make the test +/// fail. +struct ConstantComparesGatherer { + + Value *CompValue; /// Value found for the switch comparison + Value *Extra; /// Extra clause to be checked before the switch + SmallVector<ConstantInt *, 8> Vals; /// Set of integers to match in switch + unsigned UsedICmps; /// Number of comparisons matched in the and/or chain + + /// Construct and compute the result for the comparison instruction Cond + ConstantComparesGatherer(Instruction *Cond, const DataLayout *DL) + : CompValue(nullptr), Extra(nullptr), UsedICmps(0) { + gather(Cond, DL); + } + + /// Prevent copy + ConstantComparesGatherer(const ConstantComparesGatherer &) + LLVM_DELETED_FUNCTION; + ConstantComparesGatherer & + operator=(const ConstantComparesGatherer &) LLVM_DELETED_FUNCTION; + +private: - UsedICmps++; - Vals.push_back(C); - return I->getOperand(0); + /// Try to set the current value used for the comparison, it succeeds only if + /// it wasn't set before or if the new value is the same as the old one + bool setValueOnce(Value *NewVal) { + if(CompValue && CompValue != NewVal) return false; + CompValue = NewVal; + return (CompValue != nullptr); + } + + /// Try to match Instruction "I" as a comparison against a constant and + /// populates the array Vals with the set of values that match (or do not + /// match depending on isEQ). + /// Return false on failure. On success, the Value the comparison matched + /// against is placed in CompValue. + /// If CompValue is already set, the function is expected to fail if a match + /// is found but the value compared to is different. + bool matchInstruction(Instruction *I, const DataLayout *DL, bool isEQ) { + // If this is an icmp against a constant, handle this as one of the cases. + ICmpInst *ICI; + ConstantInt *C; + if (!((ICI = dyn_cast<ICmpInst>(I)) && + (C = GetConstantInt(I->getOperand(1), DL)))) { + return false; + } + + Value *RHSVal; + ConstantInt *RHSC; + + // Pattern match a special case + // (x & ~2^x) == y --> x == y || x == y|2^x + // This undoes a transformation done by instcombine to fuse 2 compares. + if (ICI->getPredicate() == (isEQ ? ICmpInst::ICMP_EQ:ICmpInst::ICMP_NE)) { + if (match(ICI->getOperand(0), + m_And(m_Value(RHSVal), m_ConstantInt(RHSC)))) { + APInt Not = ~RHSC->getValue(); + if (Not.isPowerOf2()) { + // If we already have a value for the switch, it has to match! + if(!setValueOnce(RHSVal)) + return false; + + Vals.push_back(C); + Vals.push_back(ConstantInt::get(C->getContext(), + C->getValue() | Not)); + UsedICmps++; + return true; + } } - // If we have "x ult 3" comparison, for example, then we can add 0,1,2 to - // the set. - ConstantRange Span = - ConstantRange::makeICmpRegion(ICI->getPredicate(), C->getValue()); - - // Shift the range if the compare is fed by an add. This is the range - // compare idiom as emitted by instcombine. - bool hasAdd = - match(I->getOperand(0), m_Add(m_Value(RHSVal), m_ConstantInt(RHSC))); - if (hasAdd) - Span = Span.subtract(RHSC->getValue()); - - // If this is an and/!= check then we want to optimize "x ugt 2" into - // x != 0 && x != 1. - if (!isEQ) - Span = Span.inverse(); - - // If there are a ton of values, we don't want to make a ginormous switch. - if (Span.getSetSize().ugt(8) || Span.isEmptySet()) - return nullptr; - - for (APInt Tmp = Span.getLower(); Tmp != Span.getUpper(); ++Tmp) - Vals.push_back(ConstantInt::get(V->getContext(), Tmp)); + // If we already have a value for the switch, it has to match! + if(!setValueOnce(ICI->getOperand(0))) + return false; + UsedICmps++; - return hasAdd ? RHSVal : I->getOperand(0); + Vals.push_back(C); + return ICI->getOperand(0); } - return nullptr; - } - // Otherwise, we can only handle an | or &, depending on isEQ. - if (I->getOpcode() != (isEQ ? Instruction::Or : Instruction::And)) - return nullptr; + // If we have "x ult 3", for example, then we can add 0,1,2 to the set. + ConstantRange Span = ConstantRange::makeICmpRegion(ICI->getPredicate(), + C->getValue()); - unsigned NumValsBeforeLHS = Vals.size(); - unsigned UsedICmpsBeforeLHS = UsedICmps; - if (Value *LHS = GatherConstantCompares(I->getOperand(0), Vals, Extra, DL, - isEQ, UsedICmps)) { - unsigned NumVals = Vals.size(); - unsigned UsedICmpsBeforeRHS = UsedICmps; - if (Value *RHS = GatherConstantCompares(I->getOperand(1), Vals, Extra, DL, - isEQ, UsedICmps)) { - if (LHS == RHS) - return LHS; - Vals.resize(NumVals); - UsedICmps = UsedICmpsBeforeRHS; + // Shift the range if the compare is fed by an add. This is the range + // compare idiom as emitted by instcombine. + Value *CandidateVal = I->getOperand(0); + if(match(I->getOperand(0), m_Add(m_Value(RHSVal), m_ConstantInt(RHSC)))) { + Span = Span.subtract(RHSC->getValue()); + CandidateVal = RHSVal; } - // The RHS of the or/and can't be folded in and we haven't used "Extra" yet, - // set it and return success. - if (Extra == nullptr || Extra == I->getOperand(1)) { - Extra = I->getOperand(1); - return LHS; + // If this is an and/!= check, then we are looking to build the set of + // value that *don't* pass the and chain. I.e. to turn "x ugt 2" into + // x != 0 && x != 1. + if (!isEQ) + Span = Span.inverse(); + + // If there are a ton of values, we don't want to make a ginormous switch. + if (Span.getSetSize().ugt(8) || Span.isEmptySet()) { + return false; } - Vals.resize(NumValsBeforeLHS); - UsedICmps = UsedICmpsBeforeLHS; - return nullptr; + // If we already have a value for the switch, it has to match! + if(!setValueOnce(CandidateVal)) + return false; + + // Add all values from the range to the set + for (APInt Tmp = Span.getLower(); Tmp != Span.getUpper(); ++Tmp) + Vals.push_back(ConstantInt::get(I->getContext(), Tmp)); + + UsedICmps++; + return true; + } - // If the LHS can't be folded in, but Extra is available and RHS can, try to - // use LHS as Extra. - if (Extra == nullptr || Extra == I->getOperand(0)) { - Value *OldExtra = Extra; - Extra = I->getOperand(0); - if (Value *RHS = GatherConstantCompares(I->getOperand(1), Vals, Extra, DL, - isEQ, UsedICmps)) - return RHS; - assert(Vals.size() == NumValsBeforeLHS); - Extra = OldExtra; + /// gather - Given a potentially 'or'd or 'and'd together collection of icmp + /// eq/ne/lt/gt instructions that compare a value against a constant, extract + /// the value being compared, and stick the list constants into the Vals + /// vector. + /// One "Extra" case is allowed to differ from the other. + void gather(Value *V, const DataLayout *DL) { + Instruction *I = dyn_cast<Instruction>(V); + bool isEQ = (I->getOpcode() == Instruction::Or); + + // Keep a stack (SmallVector for efficiency) for depth-first traversal + SmallVector<Value *, 8> DFT; + + // Initialize + DFT.push_back(V); + + while(!DFT.empty()) { + V = DFT.pop_back_val(); + + if (Instruction *I = dyn_cast<Instruction>(V)) { + // If it is a || (or && depending on isEQ), process the operands. + if (I->getOpcode() == (isEQ ? Instruction::Or : Instruction::And)) { + DFT.push_back(I->getOperand(1)); + DFT.push_back(I->getOperand(0)); + continue; + } + + // Try to match the current instruction + if (matchInstruction(I, DL, isEQ)) + // Match succeed, continue the loop + continue; + } + + // One element of the sequence of || (or &&) could not be match as a + // comparison against the same value as the others. + // We allow only one "Extra" case to be checked before the switch + if (!Extra) { + Extra = V; + continue; + } + // Failed to parse a proper sequence, abort now + CompValue = nullptr; + break; + } } +}; - return nullptr; } static void EraseTerminatorInstAndDCECond(TerminatorInst *TI) { @@ -628,13 +708,12 @@ SimplifyEqualityComparisonWithOnlyPredecessor(TerminatorInst *TI, // Collect branch weights into a vector. SmallVector<uint32_t, 8> Weights; - MDNode* MD = SI->getMetadata(LLVMContext::MD_prof); + MDNode *MD = SI->getMetadata(LLVMContext::MD_prof); bool HasWeight = MD && (MD->getNumOperands() == 2 + SI->getNumCases()); if (HasWeight) for (unsigned MD_i = 1, MD_e = MD->getNumOperands(); MD_i < MD_e; ++MD_i) { - ConstantInt* CI = dyn_cast<ConstantInt>(MD->getOperand(MD_i)); - assert(CI); + ConstantInt *CI = mdconst::extract<ConstantInt>(MD->getOperand(MD_i)); Weights.push_back(CI->getValue().getZExtValue()); } for (SwitchInst::CaseIt i = SI->case_end(), e = SI->case_begin(); i != e;) { @@ -723,7 +802,7 @@ static int ConstantIntSortPredicate(ConstantInt *const *P1, } static inline bool HasBranchWeights(const Instruction* I) { - MDNode* ProfMD = I->getMetadata(LLVMContext::MD_prof); + MDNode *ProfMD = I->getMetadata(LLVMContext::MD_prof); if (ProfMD && ProfMD->getOperand(0)) if (MDString* MDS = dyn_cast<MDString>(ProfMD->getOperand(0))) return MDS->getString().equals("branch_weights"); @@ -736,10 +815,10 @@ static inline bool HasBranchWeights(const Instruction* I) { /// metadata. static void GetBranchWeights(TerminatorInst *TI, SmallVectorImpl<uint64_t> &Weights) { - MDNode* MD = TI->getMetadata(LLVMContext::MD_prof); + MDNode *MD = TI->getMetadata(LLVMContext::MD_prof); assert(MD); for (unsigned i = 1, e = MD->getNumOperands(); i < e; ++i) { - ConstantInt *CI = cast<ConstantInt>(MD->getOperand(i)); + ConstantInt *CI = mdconst::extract<ConstantInt>(MD->getOperand(i)); Weights.push_back(CI->getValue().getZExtValue()); } @@ -995,6 +1074,8 @@ static bool isSafeToHoistInvoke(BasicBlock *BB1, BasicBlock *BB2, return true; } +static bool passingValueIsAlwaysUndefined(Value *V, Instruction *I); + /// HoistThenElseCodeToIf - Given a conditional branch that goes to BB1 and /// BB2, hoist any common code in the two blocks up into the branch block. The /// caller of this function guarantees that BI's block dominates BB1 and BB2. @@ -1040,6 +1121,14 @@ static bool HoistThenElseCodeToIf(BranchInst *BI, const DataLayout *DL) { if (!I2->use_empty()) I2->replaceAllUsesWith(I1); I1->intersectOptionalDataWith(I2); + unsigned KnownIDs[] = { + LLVMContext::MD_tbaa, + LLVMContext::MD_range, + LLVMContext::MD_fpmath, + LLVMContext::MD_invariant_load, + LLVMContext::MD_nonnull + }; + combineMetadata(I1, I2, KnownIDs); I2->eraseFromParent(); Changed = true; @@ -1072,6 +1161,12 @@ HoistTerminator: if (BB1V == BB2V) continue; + // Check for passingValueIsAlwaysUndefined here because we would rather + // eliminate undefined control flow then converting it to a select. + if (passingValueIsAlwaysUndefined(BB1V, PN) || + passingValueIsAlwaysUndefined(BB2V, PN)) + return Changed; + if (isa<ConstantExpr>(BB1V) && !isSafeToSpeculativelyExecute(BB1V, DL)) return Changed; if (isa<ConstantExpr>(BB2V) && !isSafeToSpeculativelyExecute(BB2V, DL)) @@ -1149,14 +1244,13 @@ static bool SinkThenElseCodeToEnd(BranchInst *BI1) { return false; // Gather the PHI nodes in BBEnd. - std::map<Value*, std::pair<Value*, PHINode*> > MapValueFromBB1ToBB2; + SmallDenseMap<std::pair<Value *, Value *>, PHINode *> JointValueMap; Instruction *FirstNonPhiInBBEnd = nullptr; - for (BasicBlock::iterator I = BBEnd->begin(), E = BBEnd->end(); - I != E; ++I) { + for (BasicBlock::iterator I = BBEnd->begin(), E = BBEnd->end(); I != E; ++I) { if (PHINode *PN = dyn_cast<PHINode>(I)) { Value *BB1V = PN->getIncomingValueForBlock(BB1); Value *BB2V = PN->getIncomingValueForBlock(BB2); - MapValueFromBB1ToBB2[BB1V] = std::make_pair(BB2V, PN); + JointValueMap[std::make_pair(BB1V, BB2V)] = PN; } else { FirstNonPhiInBBEnd = &*I; break; @@ -1165,13 +1259,13 @@ static bool SinkThenElseCodeToEnd(BranchInst *BI1) { if (!FirstNonPhiInBBEnd) return false; - // This does very trivial matching, with limited scanning, to find identical // instructions in the two blocks. We scan backward for obviously identical // instructions in an identical order. BasicBlock::InstListType::reverse_iterator RI1 = BB1->getInstList().rbegin(), - RE1 = BB1->getInstList().rend(), RI2 = BB2->getInstList().rbegin(), - RE2 = BB2->getInstList().rend(); + RE1 = BB1->getInstList().rend(), + RI2 = BB2->getInstList().rbegin(), + RE2 = BB2->getInstList().rend(); // Skip debug info. while (RI1 != RE1 && isa<DbgInfoIntrinsic>(&*RI1)) ++RI1; if (RI1 == RE1) @@ -1194,6 +1288,7 @@ static bool SinkThenElseCodeToEnd(BranchInst *BI1) { return Changed; Instruction *I1 = &*RI1, *I2 = &*RI2; + auto InstPair = std::make_pair(I1, I2); // I1 and I2 should have a single use in the same PHI node, and they // perform the same operation. // Cannot move control-flow-involving, volatile loads, vaarg, etc. @@ -1204,11 +1299,11 @@ static bool SinkThenElseCodeToEnd(BranchInst *BI1) { I1->mayHaveSideEffects() || I2->mayHaveSideEffects() || I1->mayReadOrWriteMemory() || I2->mayReadOrWriteMemory() || !I1->hasOneUse() || !I2->hasOneUse() || - MapValueFromBB1ToBB2.find(I1) == MapValueFromBB1ToBB2.end() || - MapValueFromBB1ToBB2[I1].first != I2) + !JointValueMap.count(InstPair)) return Changed; // Check whether we should swap the operands of ICmpInst. + // TODO: Add support of communativity. ICmpInst *ICmp1 = dyn_cast<ICmpInst>(I1), *ICmp2 = dyn_cast<ICmpInst>(I2); bool SwapOpnds = false; if (ICmp1 && ICmp2 && @@ -1229,16 +1324,13 @@ static bool SinkThenElseCodeToEnd(BranchInst *BI1) { // with a PHI node after sinking. We only handle the case where there is // a single pair of different operands. Value *DifferentOp1 = nullptr, *DifferentOp2 = nullptr; - unsigned Op1Idx = 0; + unsigned Op1Idx = ~0U; for (unsigned I = 0, E = I1->getNumOperands(); I != E; ++I) { if (I1->getOperand(I) == I2->getOperand(I)) continue; - // Early exit if we have more-than one pair of different operands or - // the different operand is already in MapValueFromBB1ToBB2. - // Early exit if we need a PHI node to replace a constant. - if (DifferentOp1 || - MapValueFromBB1ToBB2.find(I1->getOperand(I)) != - MapValueFromBB1ToBB2.end() || + // Early exit if we have more-than one pair of different operands or if + // we need a PHI node to replace a constant. + if (Op1Idx != ~0U || isa<Constant>(I1->getOperand(I)) || isa<Constant>(I2->getOperand(I))) { // If we can't sink the instructions, undo the swapping. @@ -1251,24 +1343,27 @@ static bool SinkThenElseCodeToEnd(BranchInst *BI1) { DifferentOp2 = I2->getOperand(I); } - // We insert the pair of different operands to MapValueFromBB1ToBB2 and - // remove (I1, I2) from MapValueFromBB1ToBB2. - if (DifferentOp1) { - PHINode *NewPN = PHINode::Create(DifferentOp1->getType(), 2, - DifferentOp1->getName() + ".sink", - BBEnd->begin()); - MapValueFromBB1ToBB2[DifferentOp1] = std::make_pair(DifferentOp2, NewPN); + DEBUG(dbgs() << "SINK common instructions " << *I1 << "\n"); + DEBUG(dbgs() << " " << *I2 << "\n"); + + // We insert the pair of different operands to JointValueMap and + // remove (I1, I2) from JointValueMap. + if (Op1Idx != ~0U) { + auto &NewPN = JointValueMap[std::make_pair(DifferentOp1, DifferentOp2)]; + if (!NewPN) { + NewPN = + PHINode::Create(DifferentOp1->getType(), 2, + DifferentOp1->getName() + ".sink", BBEnd->begin()); + NewPN->addIncoming(DifferentOp1, BB1); + NewPN->addIncoming(DifferentOp2, BB2); + DEBUG(dbgs() << "Create PHI node " << *NewPN << "\n";); + } // I1 should use NewPN instead of DifferentOp1. I1->setOperand(Op1Idx, NewPN); - NewPN->addIncoming(DifferentOp1, BB1); - NewPN->addIncoming(DifferentOp2, BB2); - DEBUG(dbgs() << "Create PHI node " << *NewPN << "\n";); } - PHINode *OldPN = MapValueFromBB1ToBB2[I1].second; - MapValueFromBB1ToBB2.erase(I1); + PHINode *OldPN = JointValueMap[InstPair]; + JointValueMap.erase(InstPair); - DEBUG(dbgs() << "SINK common instructions " << *I1 << "\n";); - DEBUG(dbgs() << " " << *I2 << "\n";); // We need to update RE1 and RE2 if we are going to sink the first // instruction in the basic block down. bool UpdateRE1 = (I1 == BB1->begin()), UpdateRE2 = (I2 == BB2->begin()); @@ -1281,6 +1376,8 @@ static bool SinkThenElseCodeToEnd(BranchInst *BI1) { if (!I2->use_empty()) I2->replaceAllUsesWith(I1); I1->intersectOptionalDataWith(I2); + // TODO: Use combineMetadata here to preserve what metadata we can + // (analogous to the hoisting case above). I2->eraseFromParent(); if (UpdateRE1) @@ -1486,6 +1583,11 @@ static bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, if (ThenV == OrigV) continue; + // Don't convert to selects if we could remove undefined behavior instead. + if (passingValueIsAlwaysUndefined(OrigV, PN) || + passingValueIsAlwaysUndefined(ThenV, PN)) + return false; + HaveRewritablePHIs = true; ConstantExpr *OrigCE = dyn_cast<ConstantExpr>(OrigV); ConstantExpr *ThenCE = dyn_cast<ConstantExpr>(ThenV); @@ -1934,8 +2036,10 @@ static bool ExtractBranchMetadata(BranchInst *BI, "Looking for probabilities on unconditional branch?"); MDNode *ProfileData = BI->getMetadata(LLVMContext::MD_prof); if (!ProfileData || ProfileData->getNumOperands() != 3) return false; - ConstantInt *CITrue = dyn_cast<ConstantInt>(ProfileData->getOperand(1)); - ConstantInt *CIFalse = dyn_cast<ConstantInt>(ProfileData->getOperand(2)); + ConstantInt *CITrue = + mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(1)); + ConstantInt *CIFalse = + mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2)); if (!CITrue || !CIFalse) return false; ProbTrue = CITrue->getValue().getZExtValue(); ProbFalse = CIFalse->getValue().getZExtValue(); @@ -1963,7 +2067,8 @@ static bool checkCSEInPredecessor(Instruction *Inst, BasicBlock *PB) { /// FoldBranchToCommonDest - If this basic block is simple enough, and if a /// predecessor branches to us and one of our successors, fold the block into /// the predecessor and use logical operations to pick the right destination. -bool llvm::FoldBranchToCommonDest(BranchInst *BI, const DataLayout *DL) { +bool llvm::FoldBranchToCommonDest(BranchInst *BI, const DataLayout *DL, + unsigned BonusInstThreshold) { BasicBlock *BB = BI->getParent(); Instruction *Cond = nullptr; @@ -2000,33 +2105,6 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, const DataLayout *DL) { Cond->getParent() != BB || !Cond->hasOneUse()) return false; - // Only allow this if the condition is a simple instruction that can be - // executed unconditionally. It must be in the same block as the branch, and - // must be at the front of the block. - BasicBlock::iterator FrontIt = BB->front(); - - // Ignore dbg intrinsics. - while (isa<DbgInfoIntrinsic>(FrontIt)) ++FrontIt; - - // Allow a single instruction to be hoisted in addition to the compare - // that feeds the branch. We later ensure that any values that _it_ uses - // were also live in the predecessor, so that we don't unnecessarily create - // register pressure or inhibit out-of-order execution. - Instruction *BonusInst = nullptr; - if (&*FrontIt != Cond && - FrontIt->hasOneUse() && FrontIt->user_back() == Cond && - isSafeToSpeculativelyExecute(FrontIt, DL)) { - BonusInst = &*FrontIt; - ++FrontIt; - - // Ignore dbg intrinsics. - while (isa<DbgInfoIntrinsic>(FrontIt)) ++FrontIt; - } - - // Only a single bonus inst is allowed. - if (&*FrontIt != Cond) - return false; - // Make sure the instruction after the condition is the cond branch. BasicBlock::iterator CondIt = Cond; ++CondIt; @@ -2036,6 +2114,31 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, const DataLayout *DL) { if (&*CondIt != BI) return false; + // Only allow this transformation if computing the condition doesn't involve + // too many instructions and these involved instructions can be executed + // unconditionally. We denote all involved instructions except the condition + // as "bonus instructions", and only allow this transformation when the + // number of the bonus instructions does not exceed a certain threshold. + unsigned NumBonusInsts = 0; + for (auto I = BB->begin(); Cond != I; ++I) { + // Ignore dbg intrinsics. + if (isa<DbgInfoIntrinsic>(I)) + continue; + if (!I->hasOneUse() || !isSafeToSpeculativelyExecute(I, DL)) + return false; + // I has only one use and can be executed unconditionally. + Instruction *User = dyn_cast<Instruction>(I->user_back()); + if (User == nullptr || User->getParent() != BB) + return false; + // I is used in the same BB. Since BI uses Cond and doesn't have more slots + // to use any other instruction, User must be an instruction between next(I) + // and Cond. + ++NumBonusInsts; + // Early exits once we reach the limit. + if (NumBonusInsts > BonusInstThreshold) + return false; + } + // Cond is known to be a compare or binary operator. Check to make sure that // neither operand is a potentially-trapping constant expression. if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Cond->getOperand(0))) @@ -2086,49 +2189,6 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, const DataLayout *DL) { continue; } - // Ensure that any values used in the bonus instruction are also used - // by the terminator of the predecessor. This means that those values - // must already have been resolved, so we won't be inhibiting the - // out-of-order core by speculating them earlier. We also allow - // instructions that are used by the terminator's condition because it - // exposes more merging opportunities. - bool UsedByBranch = (BonusInst && BonusInst->hasOneUse() && - BonusInst->user_back() == Cond); - - if (BonusInst && !UsedByBranch) { - // Collect the values used by the bonus inst - SmallPtrSet<Value*, 4> UsedValues; - for (Instruction::op_iterator OI = BonusInst->op_begin(), - OE = BonusInst->op_end(); OI != OE; ++OI) { - Value *V = *OI; - if (!isa<Constant>(V) && !isa<Argument>(V)) - UsedValues.insert(V); - } - - SmallVector<std::pair<Value*, unsigned>, 4> Worklist; - Worklist.push_back(std::make_pair(PBI->getOperand(0), 0)); - - // Walk up to four levels back up the use-def chain of the predecessor's - // terminator to see if all those values were used. The choice of four - // levels is arbitrary, to provide a compile-time-cost bound. - while (!Worklist.empty()) { - std::pair<Value*, unsigned> Pair = Worklist.back(); - Worklist.pop_back(); - - if (Pair.second >= 4) continue; - UsedValues.erase(Pair.first); - if (UsedValues.empty()) break; - - if (Instruction *I = dyn_cast<Instruction>(Pair.first)) { - for (Instruction::op_iterator OI = I->op_begin(), OE = I->op_end(); - OI != OE; ++OI) - Worklist.push_back(std::make_pair(OI->get(), Pair.second+1)); - } - } - - if (!UsedValues.empty()) return false; - } - DEBUG(dbgs() << "FOLDING BRANCH TO COMMON DEST:\n" << *PBI << *BB); IRBuilder<> Builder(PBI); @@ -2148,30 +2208,41 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, const DataLayout *DL) { PBI->swapSuccessors(); } - // If we have a bonus inst, clone it into the predecessor block. - Instruction *NewBonus = nullptr; - if (BonusInst) { - NewBonus = BonusInst->clone(); + // If we have bonus instructions, clone them into the predecessor block. + // Note that there may be mutliple predecessor blocks, so we cannot move + // bonus instructions to a predecessor block. + ValueToValueMapTy VMap; // maps original values to cloned values + // We already make sure Cond is the last instruction before BI. Therefore, + // every instructions before Cond other than DbgInfoIntrinsic are bonus + // instructions. + for (auto BonusInst = BB->begin(); Cond != BonusInst; ++BonusInst) { + if (isa<DbgInfoIntrinsic>(BonusInst)) + continue; + Instruction *NewBonusInst = BonusInst->clone(); + RemapInstruction(NewBonusInst, VMap, + RF_NoModuleLevelChanges | RF_IgnoreMissingEntries); + VMap[BonusInst] = NewBonusInst; // If we moved a load, we cannot any longer claim any knowledge about // its potential value. The previous information might have been valid // only given the branch precondition. // For an analogous reason, we must also drop all the metadata whose // semantics we don't understand. - NewBonus->dropUnknownMetadata(LLVMContext::MD_dbg); + NewBonusInst->dropUnknownMetadata(LLVMContext::MD_dbg); - PredBlock->getInstList().insert(PBI, NewBonus); - NewBonus->takeName(BonusInst); - BonusInst->setName(BonusInst->getName()+".old"); + PredBlock->getInstList().insert(PBI, NewBonusInst); + NewBonusInst->takeName(BonusInst); + BonusInst->setName(BonusInst->getName() + ".old"); } // Clone Cond into the predecessor basic block, and or/and the // two conditions together. Instruction *New = Cond->clone(); - if (BonusInst) New->replaceUsesOfWith(BonusInst, NewBonus); + RemapInstruction(New, VMap, + RF_NoModuleLevelChanges | RF_IgnoreMissingEntries); PredBlock->getInstList().insert(PBI, New); New->takeName(Cond); - Cond->setName(New->getName()+".old"); + Cond->setName(New->getName() + ".old"); if (BI->isConditional()) { Instruction *NewCond = @@ -2649,7 +2720,7 @@ static bool SimplifyIndirectBrOnSelect(IndirectBrInst *IBI, SelectInst *SI) { /// the PHI, merging the third icmp into the switch. static bool TryToSimplifyUncondBranchWithICmpInIt( ICmpInst *ICI, IRBuilder<> &Builder, const TargetTransformInfo &TTI, - const DataLayout *DL) { + unsigned BonusInstThreshold, const DataLayout *DL, AssumptionCache *AC) { BasicBlock *BB = ICI->getParent(); // If the block has any PHIs in it or the icmp has multiple uses, it is too @@ -2682,7 +2753,7 @@ static bool TryToSimplifyUncondBranchWithICmpInIt( ICI->eraseFromParent(); } // BB is now empty, so it is likely to simplify away. - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; } // Ok, the block is reachable from the default dest. If the constant we're @@ -2698,7 +2769,7 @@ static bool TryToSimplifyUncondBranchWithICmpInIt( ICI->replaceAllUsesWith(V); ICI->eraseFromParent(); // BB is now empty, so it is likely to simplify away. - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; } // The use of the icmp has to be in the 'end' block, by the only PHI node in @@ -2759,24 +2830,17 @@ static bool SimplifyBranchOnICmpChain(BranchInst *BI, const DataLayout *DL, Instruction *Cond = dyn_cast<Instruction>(BI->getCondition()); if (!Cond) return false; - // Change br (X == 0 | X == 1), T, F into a switch instruction. // If this is a bunch of seteq's or'd together, or if it's a bunch of // 'setne's and'ed together, collect them. - Value *CompVal = nullptr; - std::vector<ConstantInt*> Values; - bool TrueWhenEqual = true; - Value *ExtraCase = nullptr; - unsigned UsedICmps = 0; - - if (Cond->getOpcode() == Instruction::Or) { - CompVal = GatherConstantCompares(Cond, Values, ExtraCase, DL, true, - UsedICmps); - } else if (Cond->getOpcode() == Instruction::And) { - CompVal = GatherConstantCompares(Cond, Values, ExtraCase, DL, false, - UsedICmps); - TrueWhenEqual = false; - } + + // Try to gather values from a chain of and/or to be turned into a switch + ConstantComparesGatherer ConstantCompare(Cond, DL); + // Unpack the result + SmallVectorImpl<ConstantInt*> &Values = ConstantCompare.Vals; + Value *CompVal = ConstantCompare.CompValue; + unsigned UsedICmps = ConstantCompare.UsedICmps; + Value *ExtraCase = ConstantCompare.Extra; // If we didn't have a multiply compared value, fail. if (!CompVal) return false; @@ -2785,6 +2849,8 @@ static bool SimplifyBranchOnICmpChain(BranchInst *BI, const DataLayout *DL, if (UsedICmps <= 1) return false; + bool TrueWhenEqual = (Cond->getOpcode() == Instruction::Or); + // There might be duplicate constants in the list, which the switch // instruction can't handle, remove them now. array_pod_sort(Values.begin(), Values.end(), ConstantIntSortPredicate); @@ -2954,7 +3020,7 @@ bool SimplifyCFGOpt::SimplifyReturn(ReturnInst *RI, IRBuilder<> &Builder) { } // If we eliminated all predecessors of the block, delete the block now. - if (pred_begin(BB) == pred_end(BB)) + if (pred_empty(BB)) // We know there are no successors, so just nuke the block. BB->eraseFromParent(); @@ -3127,7 +3193,7 @@ bool SimplifyCFGOpt::SimplifyUnreachable(UnreachableInst *UI) { } // If this block is now dead, remove it. - if (pred_begin(BB) == pred_end(BB) && + if (pred_empty(BB) && BB != &BB->getParent()->getEntryBlock()) { // We know there are no successors, so just nuke the block. BB->eraseFromParent(); @@ -3208,11 +3274,12 @@ static bool TurnSwitchRangeIntoICmp(SwitchInst *SI, IRBuilder<> &Builder) { /// EliminateDeadSwitchCases - Compute masked bits for the condition of a switch /// and use it to remove dead cases. -static bool EliminateDeadSwitchCases(SwitchInst *SI) { +static bool EliminateDeadSwitchCases(SwitchInst *SI, const DataLayout *DL, + AssumptionCache *AC) { Value *Cond = SI->getCondition(); unsigned Bits = Cond->getType()->getIntegerBitWidth(); APInt KnownZero(Bits, 0), KnownOne(Bits, 0); - computeKnownBits(Cond, KnownZero, KnownOne); + computeKnownBits(Cond, KnownZero, KnownOne, DL, 0, AC, SI); // Gather dead cases. SmallVector<ConstantInt*, 8> DeadCases; @@ -3419,6 +3486,21 @@ GetCaseResults(SwitchInst *SI, continue; } else if (Constant *C = ConstantFold(I, ConstantPool, DL)) { // Instruction is side-effect free and constant. + + // If the instruction has uses outside this block or a phi node slot for + // the block, it is not safe to bypass the instruction since it would then + // no longer dominate all its uses. + for (auto &Use : I->uses()) { + User *User = Use.getUser(); + if (Instruction *I = dyn_cast<Instruction>(User)) + if (I->getParent() == CaseDest) + continue; + if (PHINode *Phi = dyn_cast<PHINode>(User)) + if (Phi->getIncomingBlock(Use) == CaseDest) + continue; + return false; + } + ConstantPool.insert(std::make_pair(I, C)); } else { break; @@ -3444,12 +3526,6 @@ GetCaseResults(SwitchInst *SI, if (!ConstVal) return false; - // Note: If the constant comes from constant-propagating the case value - // through the CaseDest basic block, it will be safe to remove the - // instructions in that block. They cannot be used (except in the phi nodes - // we visit) outside CaseDest, because that block does not dominate its - // successor. If it did, we would not be in this phi node. - // Be conservative about which kinds of constants we support. if (!ValidLookupTableConstant(ConstVal)) return false; @@ -3460,6 +3536,163 @@ GetCaseResults(SwitchInst *SI, return Res.size() > 0; } +// MapCaseToResult - Helper function used to +// add CaseVal to the list of cases that generate Result. +static void MapCaseToResult(ConstantInt *CaseVal, + SwitchCaseResultVectorTy &UniqueResults, + Constant *Result) { + for (auto &I : UniqueResults) { + if (I.first == Result) { + I.second.push_back(CaseVal); + return; + } + } + UniqueResults.push_back(std::make_pair(Result, + SmallVector<ConstantInt*, 4>(1, CaseVal))); +} + +// InitializeUniqueCases - Helper function that initializes a map containing +// results for the PHI node of the common destination block for a switch +// instruction. Returns false if multiple PHI nodes have been found or if +// there is not a common destination block for the switch. +static bool InitializeUniqueCases( + SwitchInst *SI, const DataLayout *DL, PHINode *&PHI, + BasicBlock *&CommonDest, + SwitchCaseResultVectorTy &UniqueResults, + Constant *&DefaultResult) { + for (auto &I : SI->cases()) { + ConstantInt *CaseVal = I.getCaseValue(); + + // Resulting value at phi nodes for this case value. + SwitchCaseResultsTy Results; + if (!GetCaseResults(SI, CaseVal, I.getCaseSuccessor(), &CommonDest, Results, + DL)) + return false; + + // Only one value per case is permitted + if (Results.size() > 1) + return false; + MapCaseToResult(CaseVal, UniqueResults, Results.begin()->second); + + // Check the PHI consistency. + if (!PHI) + PHI = Results[0].first; + else if (PHI != Results[0].first) + return false; + } + // Find the default result value. + SmallVector<std::pair<PHINode *, Constant *>, 1> DefaultResults; + BasicBlock *DefaultDest = SI->getDefaultDest(); + GetCaseResults(SI, nullptr, SI->getDefaultDest(), &CommonDest, DefaultResults, + DL); + // If the default value is not found abort unless the default destination + // is unreachable. + DefaultResult = + DefaultResults.size() == 1 ? DefaultResults.begin()->second : nullptr; + if ((!DefaultResult && + !isa<UnreachableInst>(DefaultDest->getFirstNonPHIOrDbg()))) + return false; + + return true; +} + +// ConvertTwoCaseSwitch - Helper function that checks if it is possible to +// transform a switch with only two cases (or two cases + default) +// that produces a result into a value select. +// Example: +// switch (a) { +// case 10: %0 = icmp eq i32 %a, 10 +// return 10; %1 = select i1 %0, i32 10, i32 4 +// case 20: ----> %2 = icmp eq i32 %a, 20 +// return 2; %3 = select i1 %2, i32 2, i32 %1 +// default: +// return 4; +// } +static Value * +ConvertTwoCaseSwitch(const SwitchCaseResultVectorTy &ResultVector, + Constant *DefaultResult, Value *Condition, + IRBuilder<> &Builder) { + assert(ResultVector.size() == 2 && + "We should have exactly two unique results at this point"); + // If we are selecting between only two cases transform into a simple + // select or a two-way select if default is possible. + if (ResultVector[0].second.size() == 1 && + ResultVector[1].second.size() == 1) { + ConstantInt *const FirstCase = ResultVector[0].second[0]; + ConstantInt *const SecondCase = ResultVector[1].second[0]; + + bool DefaultCanTrigger = DefaultResult; + Value *SelectValue = ResultVector[1].first; + if (DefaultCanTrigger) { + Value *const ValueCompare = + Builder.CreateICmpEQ(Condition, SecondCase, "switch.selectcmp"); + SelectValue = Builder.CreateSelect(ValueCompare, ResultVector[1].first, + DefaultResult, "switch.select"); + } + Value *const ValueCompare = + Builder.CreateICmpEQ(Condition, FirstCase, "switch.selectcmp"); + return Builder.CreateSelect(ValueCompare, ResultVector[0].first, SelectValue, + "switch.select"); + } + + return nullptr; +} + +// RemoveSwitchAfterSelectConversion - Helper function to cleanup a switch +// instruction that has been converted into a select, fixing up PHI nodes and +// basic blocks. +static void RemoveSwitchAfterSelectConversion(SwitchInst *SI, PHINode *PHI, + Value *SelectValue, + IRBuilder<> &Builder) { + BasicBlock *SelectBB = SI->getParent(); + while (PHI->getBasicBlockIndex(SelectBB) >= 0) + PHI->removeIncomingValue(SelectBB); + PHI->addIncoming(SelectValue, SelectBB); + + Builder.CreateBr(PHI->getParent()); + + // Remove the switch. + for (unsigned i = 0, e = SI->getNumSuccessors(); i < e; ++i) { + BasicBlock *Succ = SI->getSuccessor(i); + + if (Succ == PHI->getParent()) + continue; + Succ->removePredecessor(SelectBB); + } + SI->eraseFromParent(); +} + +/// SwitchToSelect - If the switch is only used to initialize one or more +/// phi nodes in a common successor block with only two different +/// constant values, replace the switch with select. +static bool SwitchToSelect(SwitchInst *SI, IRBuilder<> &Builder, + const DataLayout *DL, AssumptionCache *AC) { + Value *const Cond = SI->getCondition(); + PHINode *PHI = nullptr; + BasicBlock *CommonDest = nullptr; + Constant *DefaultResult; + SwitchCaseResultVectorTy UniqueResults; + // Collect all the cases that will deliver the same value from the switch. + if (!InitializeUniqueCases(SI, DL, PHI, CommonDest, UniqueResults, + DefaultResult)) + return false; + // Selects choose between maximum two values. + if (UniqueResults.size() != 2) + return false; + assert(PHI != nullptr && "PHI for value select not found"); + + Builder.SetInsertPoint(SI); + Value *SelectValue = ConvertTwoCaseSwitch( + UniqueResults, + DefaultResult, Cond, Builder); + if (SelectValue) { + RemoveSwitchAfterSelectConversion(SI, PHI, SelectValue, Builder); + return true; + } + // The switch couldn't be converted into a select. + return false; +} + namespace { /// SwitchLookupTable - This class represents a lookup table that can be used /// to replace a switch. @@ -3493,6 +3726,11 @@ namespace { // store that single value and return it for each lookup. SingleValueKind, + // For tables where there is a linear relationship between table index + // and values. We calculate the result with a simple multiplication + // and addition instead of a table lookup. + LinearMapKind, + // For small tables with integer elements, we can pack them into a bitmap // that fits into a target-legal register. Values are retrieved by // shift and mask operations. @@ -3510,6 +3748,10 @@ namespace { ConstantInt *BitMap; IntegerType *BitMapElementTy; + // For LinearMapKind, these are the constants used to derive the value. + ConstantInt *LinearOffset; + ConstantInt *LinearMultiplier; + // For ArrayKind, this is the array. GlobalVariable *Array; }; @@ -3522,7 +3764,7 @@ SwitchLookupTable::SwitchLookupTable(Module &M, Constant *DefaultValue, const DataLayout *DL) : SingleValue(nullptr), BitMap(nullptr), BitMapElementTy(nullptr), - Array(nullptr) { + LinearOffset(nullptr), LinearMultiplier(nullptr), Array(nullptr) { assert(Values.size() && "Can't build lookup table without values!"); assert(TableSize >= Values.size() && "Can't fit values in table!"); @@ -3567,6 +3809,43 @@ SwitchLookupTable::SwitchLookupTable(Module &M, return; } + // Check if we can derive the value with a linear transformation from the + // table index. + if (isa<IntegerType>(ValueType)) { + bool LinearMappingPossible = true; + APInt PrevVal; + APInt DistToPrev; + assert(TableSize >= 2 && "Should be a SingleValue table."); + // Check if there is the same distance between two consecutive values. + for (uint64_t I = 0; I < TableSize; ++I) { + ConstantInt *ConstVal = dyn_cast<ConstantInt>(TableContents[I]); + if (!ConstVal) { + // This is an undef. We could deal with it, but undefs in lookup tables + // are very seldom. It's probably not worth the additional complexity. + LinearMappingPossible = false; + break; + } + APInt Val = ConstVal->getValue(); + if (I != 0) { + APInt Dist = Val - PrevVal; + if (I == 1) { + DistToPrev = Dist; + } else if (Dist != DistToPrev) { + LinearMappingPossible = false; + break; + } + } + PrevVal = Val; + } + if (LinearMappingPossible) { + LinearOffset = cast<ConstantInt>(TableContents[0]); + LinearMultiplier = ConstantInt::get(M.getContext(), DistToPrev); + Kind = LinearMapKind; + ++NumLinearMaps; + return; + } + } + // If the type is integer and the table fits in a register, build a bitmap. if (WouldFitInRegister(DL, TableSize, ValueType)) { IntegerType *IT = cast<IntegerType>(ValueType); @@ -3602,6 +3881,16 @@ Value *SwitchLookupTable::BuildLookup(Value *Index, IRBuilder<> &Builder) { switch (Kind) { case SingleValueKind: return SingleValue; + case LinearMapKind: { + // Derive the result value from the input value. + Value *Result = Builder.CreateIntCast(Index, LinearMultiplier->getType(), + false, "switch.idx.cast"); + if (!LinearMultiplier->isOne()) + Result = Builder.CreateMul(Result, LinearMultiplier, "switch.idx.mult"); + if (!LinearOffset->isZero()) + Result = Builder.CreateAdd(Result, LinearOffset, "switch.offset"); + return Result; + } case BitMapKind: { // Type of the bitmap (e.g. i59). IntegerType *MapTy = BitMap->getType(); @@ -3673,9 +3962,8 @@ static bool ShouldBuildLookupTable(SwitchInst *SI, bool AllTablesFitInRegister = true; bool HasIllegalType = false; - for (SmallDenseMap<PHINode*, Type*>::const_iterator I = ResultTypes.begin(), - E = ResultTypes.end(); I != E; ++I) { - Type *Ty = I->second; + for (const auto &I : ResultTypes) { + Type *Ty = I.second; // Saturate this flag to true. HasIllegalType = HasIllegalType || !TTI.isTypeLegal(Ty); @@ -3705,6 +3993,89 @@ static bool ShouldBuildLookupTable(SwitchInst *SI, return SI->getNumCases() * 10 >= TableSize * 4; } +/// Try to reuse the switch table index compare. Following pattern: +/// \code +/// if (idx < tablesize) +/// r = table[idx]; // table does not contain default_value +/// else +/// r = default_value; +/// if (r != default_value) +/// ... +/// \endcode +/// Is optimized to: +/// \code +/// cond = idx < tablesize; +/// if (cond) +/// r = table[idx]; +/// else +/// r = default_value; +/// if (cond) +/// ... +/// \endcode +/// Jump threading will then eliminate the second if(cond). +static void reuseTableCompare(User *PhiUser, BasicBlock *PhiBlock, + BranchInst *RangeCheckBranch, Constant *DefaultValue, + const SmallVectorImpl<std::pair<ConstantInt*, Constant*> >& Values) { + + ICmpInst *CmpInst = dyn_cast<ICmpInst>(PhiUser); + if (!CmpInst) + return; + + // We require that the compare is in the same block as the phi so that jump + // threading can do its work afterwards. + if (CmpInst->getParent() != PhiBlock) + return; + + Constant *CmpOp1 = dyn_cast<Constant>(CmpInst->getOperand(1)); + if (!CmpOp1) + return; + + Value *RangeCmp = RangeCheckBranch->getCondition(); + Constant *TrueConst = ConstantInt::getTrue(RangeCmp->getType()); + Constant *FalseConst = ConstantInt::getFalse(RangeCmp->getType()); + + // Check if the compare with the default value is constant true or false. + Constant *DefaultConst = ConstantExpr::getICmp(CmpInst->getPredicate(), + DefaultValue, CmpOp1, true); + if (DefaultConst != TrueConst && DefaultConst != FalseConst) + return; + + // Check if the compare with the case values is distinct from the default + // compare result. + for (auto ValuePair : Values) { + Constant *CaseConst = ConstantExpr::getICmp(CmpInst->getPredicate(), + ValuePair.second, CmpOp1, true); + if (!CaseConst || CaseConst == DefaultConst) + return; + assert((CaseConst == TrueConst || CaseConst == FalseConst) && + "Expect true or false as compare result."); + } + + // Check if the branch instruction dominates the phi node. It's a simple + // dominance check, but sufficient for our needs. + // Although this check is invariant in the calling loops, it's better to do it + // at this late stage. Practically we do it at most once for a switch. + BasicBlock *BranchBlock = RangeCheckBranch->getParent(); + for (auto PI = pred_begin(PhiBlock), E = pred_end(PhiBlock); PI != E; ++PI) { + BasicBlock *Pred = *PI; + if (Pred != BranchBlock && Pred->getUniquePredecessor() != BranchBlock) + return; + } + + if (DefaultConst == FalseConst) { + // The compare yields the same result. We can replace it. + CmpInst->replaceAllUsesWith(RangeCmp); + ++NumTableCmpReuses; + } else { + // The compare yields the same result, just inverted. We can replace it. + Value *InvertedTableCmp = BinaryOperator::CreateXor(RangeCmp, + ConstantInt::get(RangeCmp->getType(), 1), "inverted.cmp", + RangeCheckBranch); + CmpInst->replaceAllUsesWith(InvertedTableCmp); + ++NumTableCmpReuses; + } +} + /// SwitchToLookupTable - If the switch is only used to initialize one or more /// phi nodes in a common successor block with different constant values, /// replace the switch with lookup tables. @@ -3759,16 +4130,17 @@ static bool SwitchToLookupTable(SwitchInst *SI, return false; // Append the result from this case to the list for each phi. - for (ResultsTy::iterator I = Results.begin(), E = Results.end(); I!=E; ++I) { - if (!ResultLists.count(I->first)) - PHIs.push_back(I->first); - ResultLists[I->first].push_back(std::make_pair(CaseVal, I->second)); + for (const auto &I : Results) { + PHINode *PHI = I.first; + Constant *Value = I.second; + if (!ResultLists.count(PHI)) + PHIs.push_back(PHI); + ResultLists[PHI].push_back(std::make_pair(CaseVal, Value)); } } // Keep track of the result types. - for (size_t I = 0, E = PHIs.size(); I != E; ++I) { - PHINode *PHI = PHIs[I]; + for (PHINode *PHI : PHIs) { ResultTypes[PHI] = ResultLists[PHI][0].second->getType(); } @@ -3780,11 +4152,9 @@ static bool SwitchToLookupTable(SwitchInst *SI, // If the table has holes, we need a constant result for the default case // or a bitmask that fits in a register. SmallVector<std::pair<PHINode*, Constant*>, 4> DefaultResultsList; - bool HasDefaultResults = false; - if (TableHasHoles) { - HasDefaultResults = GetCaseResults(SI, nullptr, SI->getDefaultDest(), + bool HasDefaultResults = GetCaseResults(SI, nullptr, SI->getDefaultDest(), &CommonDest, DefaultResultsList, DL); - } + bool NeedMask = (TableHasHoles && !HasDefaultResults); if (NeedMask) { // As an extra penalty for the validity test we require more cases. @@ -3794,9 +4164,9 @@ static bool SwitchToLookupTable(SwitchInst *SI, return false; } - for (size_t I = 0, E = DefaultResultsList.size(); I != E; ++I) { - PHINode *PHI = DefaultResultsList[I].first; - Constant *Result = DefaultResultsList[I].second; + for (const auto &I : DefaultResultsList) { + PHINode *PHI = I.first; + Constant *Result = I.second; DefaultResults[PHI] = Result; } @@ -3827,14 +4197,19 @@ static bool SwitchToLookupTable(SwitchInst *SI, // lookup table BB. Otherwise, check if the condition value is within the case // range. If it is so, branch to the new BB. Otherwise branch to SI's default // destination. + BranchInst *RangeCheckBranch = nullptr; + const bool GeneratingCoveredLookupTable = MaxTableSize == TableSize; if (GeneratingCoveredLookupTable) { Builder.CreateBr(LookupBB); - SI->getDefaultDest()->removePredecessor(SI->getParent()); + // We cached PHINodes in PHIs, to avoid accessing deleted PHINodes later, + // do not delete PHINodes here. + SI->getDefaultDest()->removePredecessor(SI->getParent(), + true/*DontDeleteUselessPHIs*/); } else { Value *Cmp = Builder.CreateICmpULT(TableIndex, ConstantInt::get( MinCaseVal->getType(), TableSize)); - Builder.CreateCondBr(Cmp, LookupBB, SI->getDefaultDest()); + RangeCheckBranch = Builder.CreateCondBr(Cmp, LookupBB, SI->getDefaultDest()); } // Populate the BB that does the lookups. @@ -3851,9 +4226,12 @@ static bool SwitchToLookupTable(SwitchInst *SI, CommonDest->getParent(), CommonDest); + // Make the mask's bitwidth at least 8bit and a power-of-2 to avoid + // unnecessary illegal types. + uint64_t TableSizePowOf2 = NextPowerOf2(std::max(7ULL, TableSize - 1ULL)); + APInt MaskInt(TableSizePowOf2, 0); + APInt One(TableSizePowOf2, 1); // Build bitmask; fill in a 1 bit for every case. - APInt MaskInt(TableSize, 0); - APInt One(TableSize, 1); const ResultListTy &ResultList = ResultLists[PHIs[0]]; for (size_t I = 0, E = ResultList.size(); I != E; ++I) { uint64_t Idx = (ResultList[I].first->getValue() - @@ -3882,11 +4260,11 @@ static bool SwitchToLookupTable(SwitchInst *SI, bool ReturnedEarly = false; for (size_t I = 0, E = PHIs.size(); I != E; ++I) { PHINode *PHI = PHIs[I]; + const ResultListTy &ResultList = ResultLists[PHI]; // If using a bitmask, use any value to fill the lookup table holes. Constant *DV = NeedMask ? ResultLists[PHI][0].second : DefaultResults[PHI]; - SwitchLookupTable Table(Mod, TableSize, MinCaseVal, ResultLists[PHI], - DV, DL); + SwitchLookupTable Table(Mod, TableSize, MinCaseVal, ResultList, DV, DL); Value *Result = Table.BuildLookup(TableIndex, Builder); @@ -3899,6 +4277,16 @@ static bool SwitchToLookupTable(SwitchInst *SI, break; } + // Do a small peephole optimization: re-use the switch table compare if + // possible. + if (!TableHasHoles && HasDefaultResults && RangeCheckBranch) { + BasicBlock *PhiBlock = PHI->getParent(); + // Search for compare instructions which use the phi. + for (auto *User : PHI->users()) { + reuseTableCompare(User, PhiBlock, RangeCheckBranch, DV, ResultList); + } + } + PHI->addIncoming(Result, LookupBB); } @@ -3929,12 +4317,12 @@ bool SimplifyCFGOpt::SimplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) { // see if that predecessor totally determines the outcome of this switch. if (BasicBlock *OnlyPred = BB->getSinglePredecessor()) if (SimplifyEqualityComparisonWithOnlyPredecessor(SI, OnlyPred, Builder)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; Value *Cond = SI->getCondition(); if (SelectInst *Select = dyn_cast<SelectInst>(Cond)) if (SimplifySwitchOnSelect(SI, Select)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; // If the block only contains the switch, see if we can fold the block // away into any preds. @@ -3944,22 +4332,25 @@ bool SimplifyCFGOpt::SimplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) { ++BBI; if (SI == &*BBI) if (FoldValueComparisonIntoPredecessors(SI, Builder)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; } // Try to transform the switch into an icmp and a branch. if (TurnSwitchRangeIntoICmp(SI, Builder)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; // Remove unreachable cases. - if (EliminateDeadSwitchCases(SI)) - return SimplifyCFG(BB, TTI, DL) | true; + if (EliminateDeadSwitchCases(SI, DL, AC)) + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; + + if (SwitchToSelect(SI, Builder, DL, AC)) + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; if (ForwardSwitchConditionToPHI(SI)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; if (SwitchToLookupTable(SI, Builder, TTI, DL)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; return false; } @@ -3972,7 +4363,7 @@ bool SimplifyCFGOpt::SimplifyIndirectBr(IndirectBrInst *IBI) { SmallPtrSet<Value *, 8> Succs; for (unsigned i = 0, e = IBI->getNumDestinations(); i != e; ++i) { BasicBlock *Dest = IBI->getDestination(i); - if (!Dest->hasAddressTaken() || !Succs.insert(Dest)) { + if (!Dest->hasAddressTaken() || !Succs.insert(Dest).second) { Dest->removePredecessor(BB); IBI->removeDestination(i); --i; --e; @@ -3996,7 +4387,7 @@ bool SimplifyCFGOpt::SimplifyIndirectBr(IndirectBrInst *IBI) { if (SelectInst *SI = dyn_cast<SelectInst>(IBI->getAddress())) { if (SimplifyIndirectBrOnSelect(IBI, SI)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; } return Changed; } @@ -4008,7 +4399,7 @@ bool SimplifyCFGOpt::SimplifyUncondBranch(BranchInst *BI, IRBuilder<> &Builder){ return true; // If the Terminator is the only non-phi instruction, simplify the block. - BasicBlock::iterator I = BB->getFirstNonPHIOrDbgOrLifetime(); + BasicBlock::iterator I = BB->getFirstNonPHIOrDbg(); if (I->isTerminator() && BB != &BB->getParent()->getEntryBlock() && TryToSimplifyUncondBranchFromEmptyBlock(BB)) return true; @@ -4020,7 +4411,8 @@ bool SimplifyCFGOpt::SimplifyUncondBranch(BranchInst *BI, IRBuilder<> &Builder){ for (++I; isa<DbgInfoIntrinsic>(I); ++I) ; if (I->isTerminator() && - TryToSimplifyUncondBranchWithICmpInIt(ICI, Builder, TTI, DL)) + TryToSimplifyUncondBranchWithICmpInIt(ICI, Builder, TTI, + BonusInstThreshold, DL, AC)) return true; } @@ -4028,8 +4420,8 @@ bool SimplifyCFGOpt::SimplifyUncondBranch(BranchInst *BI, IRBuilder<> &Builder){ // branches to us and our successor, fold the comparison into the // predecessor and use logical operations to update the incoming value // for PHI nodes in common successor. - if (FoldBranchToCommonDest(BI, DL)) - return SimplifyCFG(BB, TTI, DL) | true; + if (FoldBranchToCommonDest(BI, DL, BonusInstThreshold)) + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; return false; } @@ -4044,7 +4436,7 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { // switch. if (BasicBlock *OnlyPred = BB->getSinglePredecessor()) if (SimplifyEqualityComparisonWithOnlyPredecessor(BI, OnlyPred, Builder)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; // This block must be empty, except for the setcond inst, if it exists. // Ignore dbg intrinsics. @@ -4054,14 +4446,14 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { ++I; if (&*I == BI) { if (FoldValueComparisonIntoPredecessors(BI, Builder)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; } else if (&*I == cast<Instruction>(BI->getCondition())){ ++I; // Ignore dbg intrinsics. while (isa<DbgInfoIntrinsic>(I)) ++I; if (&*I == BI && FoldValueComparisonIntoPredecessors(BI, Builder)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; } } @@ -4072,8 +4464,8 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { // If this basic block is ONLY a compare and a branch, and if a predecessor // branches to us and one of our successors, fold the comparison into the // predecessor and use logical operations to pick the right destination. - if (FoldBranchToCommonDest(BI, DL)) - return SimplifyCFG(BB, TTI, DL) | true; + if (FoldBranchToCommonDest(BI, DL, BonusInstThreshold)) + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; // We have a conditional branch to two blocks that are only reachable // from BI. We know that the condbr dominates the two blocks, so see if @@ -4082,7 +4474,7 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { if (BI->getSuccessor(0)->getSinglePredecessor()) { if (BI->getSuccessor(1)->getSinglePredecessor()) { if (HoistThenElseCodeToIf(BI, DL)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; } else { // If Successor #1 has multiple preds, we may be able to conditionally // execute Successor #0 if it branches to Successor #1. @@ -4090,7 +4482,7 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { if (Succ0TI->getNumSuccessors() == 1 && Succ0TI->getSuccessor(0) == BI->getSuccessor(1)) if (SpeculativelyExecuteBB(BI, BI->getSuccessor(0), DL)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; } } else if (BI->getSuccessor(1)->getSinglePredecessor()) { // If Successor #0 has multiple preds, we may be able to conditionally @@ -4099,7 +4491,7 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { if (Succ1TI->getNumSuccessors() == 1 && Succ1TI->getSuccessor(0) == BI->getSuccessor(0)) if (SpeculativelyExecuteBB(BI, BI->getSuccessor(1), DL)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; } // If this is a branch on a phi node in the current block, thread control @@ -4107,14 +4499,14 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { if (PHINode *PN = dyn_cast<PHINode>(BI->getCondition())) if (PN->getParent() == BI->getParent()) if (FoldCondBranchOnPHI(BI, DL)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; // Scan predecessor blocks for conditional branches. for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) if (BranchInst *PBI = dyn_cast<BranchInst>((*PI)->getTerminator())) if (PBI != BI && PBI->isConditional()) if (SimplifyCondBranchToCondBranch(PBI, BI)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; return false; } @@ -4195,7 +4587,7 @@ bool SimplifyCFGOpt::run(BasicBlock *BB) { // Remove basic blocks that have no predecessors (except the entry block)... // or that just have themself as a predecessor. These are unreachable. - if ((pred_begin(BB) == pred_end(BB) && + if ((pred_empty(BB) && BB != &BB->getParent()->getEntryBlock()) || BB->getSinglePredecessor() == BB) { DEBUG(dbgs() << "Removing BB: \n" << *BB); @@ -4258,6 +4650,7 @@ bool SimplifyCFGOpt::run(BasicBlock *BB) { /// of the CFG. It returns true if a modification was made. /// bool llvm::SimplifyCFG(BasicBlock *BB, const TargetTransformInfo &TTI, - const DataLayout *DL) { - return SimplifyCFGOpt(TTI, DL).run(BB); + unsigned BonusInstThreshold, const DataLayout *DL, + AssumptionCache *AC) { + return SimplifyCFGOpt(TTI, BonusInstThreshold, DL, AC).run(BB); } diff --git a/lib/Transforms/Utils/SimplifyIndVar.cpp b/lib/Transforms/Utils/SimplifyIndVar.cpp index b284e6f6c1f6..f8aa1d3eec12 100644 --- a/lib/Transforms/Utils/SimplifyIndVar.cpp +++ b/lib/Transforms/Utils/SimplifyIndVar.cpp @@ -40,7 +40,7 @@ STATISTIC(NumElimRem , "Number of IV remainder operations eliminated"); STATISTIC(NumElimCmp , "Number of IV comparisons eliminated"); namespace { - /// SimplifyIndvar - This is a utility for simplifying induction variables + /// This is a utility for simplifying induction variables /// based on ScalarEvolution. It is the primary instrument of the /// IndvarSimplify pass, but it may also be directly invoked to cleanup after /// other loop passes that preserve SCEV. @@ -80,13 +80,14 @@ namespace { void eliminateIVComparison(ICmpInst *ICmp, Value *IVOperand); void eliminateIVRemainder(BinaryOperator *Rem, Value *IVOperand, bool IsSigned); + bool strengthenOverflowingOperation(BinaryOperator *OBO, Value *IVOperand); Instruction *splitOverflowIntrinsic(Instruction *IVUser, const DominatorTree *DT); }; } -/// foldIVUser - Fold an IV operand into its use. This removes increments of an +/// Fold an IV operand into its use. This removes increments of an /// aligned IV when used by a instruction that ignores the low bits. /// /// IVOperand is guaranteed SCEVable, but UseInst may not be. @@ -152,7 +153,7 @@ Value *SimplifyIndvar::foldIVUser(Instruction *UseInst, Instruction *IVOperand) return IVSrc; } -/// eliminateIVComparison - SimplifyIVUsers helper for eliminating useless +/// SimplifyIVUsers helper for eliminating useless /// comparisons against an induction variable. void SimplifyIndvar::eliminateIVComparison(ICmpInst *ICmp, Value *IVOperand) { unsigned IVOperIdx = 0; @@ -188,7 +189,7 @@ void SimplifyIndvar::eliminateIVComparison(ICmpInst *ICmp, Value *IVOperand) { DeadInsts.push_back(ICmp); } -/// eliminateIVRemainder - SimplifyIVUsers helper for eliminating useless +/// SimplifyIVUsers helper for eliminating useless /// remainder operations operating on an induction variable. void SimplifyIndvar::eliminateIVRemainder(BinaryOperator *Rem, Value *IVOperand, @@ -239,7 +240,7 @@ void SimplifyIndvar::eliminateIVRemainder(BinaryOperator *Rem, DeadInsts.push_back(Rem); } -/// eliminateIVUser - Eliminate an operation that consumes a simple IV and has +/// Eliminate an operation that consumes a simple IV and has /// no observable side-effect given the range of IV values. /// IVOperand is guaranteed SCEVable, but UseInst may not be. bool SimplifyIndvar::eliminateIVUser(Instruction *UseInst, @@ -271,6 +272,120 @@ bool SimplifyIndvar::eliminateIVUser(Instruction *UseInst, return true; } +/// Annotate BO with nsw / nuw if it provably does not signed-overflow / +/// unsigned-overflow. Returns true if anything changed, false otherwise. +bool SimplifyIndvar::strengthenOverflowingOperation(BinaryOperator *BO, + Value *IVOperand) { + + // Currently we only handle instructions of the form "add <indvar> <value>" + // and "sub <indvar> <value>". + unsigned Op = BO->getOpcode(); + if (!(Op == Instruction::Add || Op == Instruction::Sub)) + return false; + + // If BO is already both nuw and nsw then there is nothing left to do + if (BO->hasNoUnsignedWrap() && BO->hasNoSignedWrap()) + return false; + + IntegerType *IT = cast<IntegerType>(IVOperand->getType()); + Value *OtherOperand = nullptr; + int OtherOperandIdx = -1; + if (BO->getOperand(0) == IVOperand) { + OtherOperand = BO->getOperand(1); + OtherOperandIdx = 1; + } else { + assert(BO->getOperand(1) == IVOperand && "only other use!"); + OtherOperand = BO->getOperand(0); + OtherOperandIdx = 0; + } + + bool Changed = false; + const SCEV *OtherOpSCEV = SE->getSCEV(OtherOperand); + if (OtherOpSCEV == SE->getCouldNotCompute()) + return false; + + if (Op == Instruction::Sub) { + // If the subtraction is of the form "sub <indvar>, <op>", then pretend it + // is "add <indvar>, -<op>" and continue, else bail out. + if (OtherOperandIdx != 1) + return false; + + OtherOpSCEV = SE->getNegativeSCEV(OtherOpSCEV); + } + + const SCEV *IVOpSCEV = SE->getSCEV(IVOperand); + const SCEV *ZeroSCEV = SE->getConstant(IVOpSCEV->getType(), 0); + + if (!BO->hasNoSignedWrap()) { + // Upgrade the add to an "add nsw" if we can prove that it will never + // sign-overflow or sign-underflow. + + const SCEV *SignedMax = + SE->getConstant(APInt::getSignedMaxValue(IT->getBitWidth())); + const SCEV *SignedMin = + SE->getConstant(APInt::getSignedMinValue(IT->getBitWidth())); + + // The addition "IVOperand + OtherOp" does not sign-overflow if the result + // is sign-representable in 2's complement in the given bit-width. + // + // If OtherOp is SLT 0, then for an IVOperand in [SignedMin - OtherOp, + // SignedMax], "IVOperand + OtherOp" is in [SignedMin, SignedMax + OtherOp]. + // Everything in [SignedMin, SignedMax + OtherOp] is representable since + // SignedMax + OtherOp is at least -1. + // + // If OtherOp is SGE 0, then for an IVOperand in [SignedMin, SignedMax - + // OtherOp], "IVOperand + OtherOp" is in [SignedMin + OtherOp, SignedMax]. + // Everything in [SignedMin + OtherOp, SignedMax] is representable since + // SignedMin + OtherOp is at most -1. + // + // It follows that for all values of IVOperand in [SignedMin - smin(0, + // OtherOp), SignedMax - smax(0, OtherOp)] the result of the add is + // representable (i.e. there is no sign-overflow). + + const SCEV *UpperDelta = SE->getSMaxExpr(ZeroSCEV, OtherOpSCEV); + const SCEV *UpperLimit = SE->getMinusSCEV(SignedMax, UpperDelta); + + bool NeverSignedOverflows = + SE->isKnownPredicate(ICmpInst::ICMP_SLE, IVOpSCEV, UpperLimit); + + if (NeverSignedOverflows) { + const SCEV *LowerDelta = SE->getSMinExpr(ZeroSCEV, OtherOpSCEV); + const SCEV *LowerLimit = SE->getMinusSCEV(SignedMin, LowerDelta); + + bool NeverSignedUnderflows = + SE->isKnownPredicate(ICmpInst::ICMP_SGE, IVOpSCEV, LowerLimit); + if (NeverSignedUnderflows) { + BO->setHasNoSignedWrap(true); + Changed = true; + } + } + } + + if (!BO->hasNoUnsignedWrap()) { + // Upgrade the add computing "IVOperand + OtherOp" to an "add nuw" if we can + // prove that it will never unsigned-overflow (i.e. the result will always + // be representable in the given bit-width). + // + // "IVOperand + OtherOp" is unsigned-representable in 2's complement iff it + // does not produce a carry. "IVOperand + OtherOp" produces no carry iff + // IVOperand ULE (UnsignedMax - OtherOp). + + const SCEV *UnsignedMax = + SE->getConstant(APInt::getMaxValue(IT->getBitWidth())); + const SCEV *UpperLimit = SE->getMinusSCEV(UnsignedMax, OtherOpSCEV); + + bool NeverUnsignedOverflows = + SE->isKnownPredicate(ICmpInst::ICMP_ULE, IVOpSCEV, UpperLimit); + + if (NeverUnsignedOverflows) { + BO->setHasNoUnsignedWrap(true); + Changed = true; + } + } + + return Changed; +} + /// \brief Split sadd.with.overflow into add + sadd.with.overflow to allow /// analysis and optimization. /// @@ -334,8 +449,7 @@ Instruction *SimplifyIndvar::splitOverflowIntrinsic(Instruction *IVUser, return AddInst; } -/// pushIVUsers - Add all uses of Def to the current IV's worklist. -/// +/// Add all uses of Def to the current IV's worklist. static void pushIVUsers( Instruction *Def, SmallPtrSet<Instruction*,16> &Simplified, @@ -348,12 +462,12 @@ static void pushIVUsers( // Also ensure unique worklist users. // If Def is a LoopPhi, it may not be in the Simplified set, so check for // self edges first. - if (UI != Def && Simplified.insert(UI)) + if (UI != Def && Simplified.insert(UI).second) SimpleIVUsers.push_back(std::make_pair(UI, Def)); } } -/// isSimpleIVUser - Return true if this instruction generates a simple SCEV +/// Return true if this instruction generates a simple SCEV /// expression in terms of that IV. /// /// This is similar to IVUsers' isInteresting() but processes each instruction @@ -374,7 +488,7 @@ static bool isSimpleIVUser(Instruction *I, const Loop *L, ScalarEvolution *SE) { return false; } -/// simplifyUsers - Iteratively perform simplification on a worklist of users +/// Iteratively perform simplification on a worklist of users /// of the specified induction variable. Each successive simplification may push /// more users which may themselves be candidates for simplification. /// @@ -431,6 +545,16 @@ void SimplifyIndvar::simplifyUsers(PHINode *CurrIV, IVVisitor *V) { pushIVUsers(IVOperand, Simplified, SimpleIVUsers); continue; } + + if (BinaryOperator *BO = dyn_cast<BinaryOperator>(UseOper.first)) { + if (isa<OverflowingBinaryOperator>(BO) && + strengthenOverflowingOperation(BO, IVOperand)) { + // re-queue uses of the now modified binary operator and fall + // through to the checks that remain. + pushIVUsers(IVOperand, Simplified, SimpleIVUsers); + } + } + CastInst *Cast = dyn_cast<CastInst>(UseOper.first); if (V && Cast) { V->visitCast(Cast); @@ -446,7 +570,7 @@ namespace llvm { void IVVisitor::anchor() { } -/// simplifyUsersOfIV - Simplify instructions that use this induction variable +/// Simplify instructions that use this induction variable /// by using ScalarEvolution to analyze the IV's recurrence. bool simplifyUsersOfIV(PHINode *CurrIV, ScalarEvolution *SE, LPPassManager *LPM, SmallVectorImpl<WeakVH> &Dead, IVVisitor *V) @@ -457,7 +581,7 @@ bool simplifyUsersOfIV(PHINode *CurrIV, ScalarEvolution *SE, LPPassManager *LPM, return SIV.hasChanged(); } -/// simplifyLoopIVs - Simplify users of induction variables within this +/// Simplify users of induction variables within this /// loop. This does not actually change or add IVs. bool simplifyLoopIVs(Loop *L, ScalarEvolution *SE, LPPassManager *LPM, SmallVectorImpl<WeakVH> &Dead) { diff --git a/lib/Transforms/Utils/SimplifyInstructions.cpp b/lib/Transforms/Utils/SimplifyInstructions.cpp index 33b36378027d..cc97098d010a 100644 --- a/lib/Transforms/Utils/SimplifyInstructions.cpp +++ b/lib/Transforms/Utils/SimplifyInstructions.cpp @@ -18,6 +18,7 @@ #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" @@ -41,6 +42,7 @@ namespace { void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); + AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<TargetLibraryInfo>(); } @@ -52,6 +54,8 @@ namespace { DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>(); const DataLayout *DL = DLP ? &DLP->getDataLayout() : nullptr; const TargetLibraryInfo *TLI = &getAnalysis<TargetLibraryInfo>(); + AssumptionCache *AC = + &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); SmallPtrSet<const Instruction*, 8> S1, S2, *ToSimplify = &S1, *Next = &S2; bool Changed = false; @@ -68,7 +72,7 @@ namespace { continue; // Don't waste time simplifying unused instructions. if (!I->use_empty()) - if (Value *V = SimplifyInstruction(I, DL, TLI, DT)) { + if (Value *V = SimplifyInstruction(I, DL, TLI, DT, AC)) { // Mark all uses for resimplification next time round the loop. for (User *U : I->users()) Next->insert(cast<Instruction>(U)); @@ -101,6 +105,7 @@ namespace { char InstSimplifier::ID = 0; INITIALIZE_PASS_BEGIN(InstSimplifier, "instsimplify", "Remove redundant instructions", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfo) INITIALIZE_PASS_END(InstSimplifier, "instsimplify", "Remove redundant instructions", false, false) diff --git a/lib/Transforms/Utils/SimplifyLibCalls.cpp b/lib/Transforms/Utils/SimplifyLibCalls.cpp index 3b61bb575a8d..5b4647ddcb5e 100644 --- a/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ b/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -27,65 +27,43 @@ #include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/Support/Allocator.h" #include "llvm/Support/CommandLine.h" #include "llvm/Target/TargetLibraryInfo.h" #include "llvm/Transforms/Utils/BuildLibCalls.h" using namespace llvm; +using namespace PatternMatch; static cl::opt<bool> -ColdErrorCalls("error-reporting-is-cold", cl::init(true), - cl::Hidden, cl::desc("Treat error-reporting calls as cold")); - -/// This class is the abstract base class for the set of optimizations that -/// corresponds to one library call. -namespace { -class LibCallOptimization { -protected: - Function *Caller; - const DataLayout *DL; - const TargetLibraryInfo *TLI; - const LibCallSimplifier *LCS; - LLVMContext* Context; -public: - LibCallOptimization() { } - virtual ~LibCallOptimization() {} - - /// callOptimizer - This pure virtual method is implemented by base classes to - /// do various optimizations. If this returns null then no transformation was - /// performed. If it returns CI, then it transformed the call and CI is to be - /// deleted. If it returns something else, replace CI with the new value and - /// delete CI. - virtual Value *callOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) - =0; - - /// ignoreCallingConv - Returns false if this transformation could possibly - /// change the calling convention. - virtual bool ignoreCallingConv() { return false; } - - Value *optimizeCall(CallInst *CI, const DataLayout *DL, - const TargetLibraryInfo *TLI, - const LibCallSimplifier *LCS, IRBuilder<> &B) { - Caller = CI->getParent()->getParent(); - this->DL = DL; - this->TLI = TLI; - this->LCS = LCS; - if (CI->getCalledFunction()) - Context = &CI->getCalledFunction()->getContext(); + ColdErrorCalls("error-reporting-is-cold", cl::init(true), cl::Hidden, + cl::desc("Treat error-reporting calls as cold")); - // We never change the calling convention. - if (!ignoreCallingConv() && CI->getCallingConv() != llvm::CallingConv::C) - return nullptr; +static cl::opt<bool> + EnableUnsafeFPShrink("enable-double-float-shrink", cl::Hidden, + cl::init(false), + cl::desc("Enable unsafe double to float " + "shrinking for math lib calls")); - return callOptimizer(CI->getCalledFunction(), CI, B); - } -}; //===----------------------------------------------------------------------===// // Helper Functions //===----------------------------------------------------------------------===// +static bool ignoreCallingConv(LibFunc::Func Func) { + switch (Func) { + case LibFunc::abs: + case LibFunc::labs: + case LibFunc::llabs: + case LibFunc::strlen: + return true; + default: + return false; + } + llvm_unreachable("All cases should be covered in the switch."); +} + /// isOnlyUsedInZeroEqualityComparison - Return true if it only matters that the /// value is equal or not-equal to zero. static bool isOnlyUsedInZeroEqualityComparison(Value *V) { @@ -138,1908 +116,1739 @@ static bool hasUnaryFloatFn(const TargetLibraryInfo *TLI, Type *Ty, } } -//===----------------------------------------------------------------------===// -// Fortified Library Call Optimizations -//===----------------------------------------------------------------------===// - -struct FortifiedLibCallOptimization : public LibCallOptimization { -protected: - virtual bool isFoldable(unsigned SizeCIOp, unsigned SizeArgOp, - bool isString) const = 0; -}; - -struct InstFortifiedLibCallOptimization : public FortifiedLibCallOptimization { - CallInst *CI; - - bool isFoldable(unsigned SizeCIOp, unsigned SizeArgOp, - bool isString) const override { - if (CI->getArgOperand(SizeCIOp) == CI->getArgOperand(SizeArgOp)) - return true; - if (ConstantInt *SizeCI = - dyn_cast<ConstantInt>(CI->getArgOperand(SizeCIOp))) { - if (SizeCI->isAllOnesValue()) - return true; - if (isString) { - uint64_t Len = GetStringLength(CI->getArgOperand(SizeArgOp)); - // If the length is 0 we don't know how long it is and so we can't - // remove the check. - if (Len == 0) return false; - return SizeCI->getZExtValue() >= Len; - } - if (ConstantInt *Arg = dyn_cast<ConstantInt>( - CI->getArgOperand(SizeArgOp))) - return SizeCI->getZExtValue() >= Arg->getZExtValue(); - } +/// \brief Returns whether \p F matches the signature expected for the +/// string/memory copying library function \p Func. +/// Acceptable functions are st[rp][n]?cpy, memove, memcpy, and memset. +/// Their fortified (_chk) counterparts are also accepted. +static bool checkStringCopyLibFuncSignature(Function *F, LibFunc::Func Func, + const DataLayout *DL) { + FunctionType *FT = F->getFunctionType(); + LLVMContext &Context = F->getContext(); + Type *PCharTy = Type::getInt8PtrTy(Context); + Type *SizeTTy = DL ? DL->getIntPtrType(Context) : nullptr; + unsigned NumParams = FT->getNumParams(); + + // All string libfuncs return the same type as the first parameter. + if (FT->getReturnType() != FT->getParamType(0)) return false; - } -}; - -struct MemCpyChkOpt : public InstFortifiedLibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - this->CI = CI; - FunctionType *FT = Callee->getFunctionType(); - LLVMContext &Context = CI->getParent()->getContext(); - - // Check if this has the right signature. - if (FT->getNumParams() != 4 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy() || - FT->getParamType(2) != DL->getIntPtrType(Context) || - FT->getParamType(3) != DL->getIntPtrType(Context)) - return nullptr; - if (isFoldable(3, 2, false)) { - B.CreateMemCpy(CI->getArgOperand(0), CI->getArgOperand(1), - CI->getArgOperand(2), 1); - return CI->getArgOperand(0); - } - return nullptr; + switch (Func) { + default: + llvm_unreachable("Can't check signature for non-string-copy libfunc."); + case LibFunc::stpncpy_chk: + case LibFunc::strncpy_chk: + --NumParams; // fallthrough + case LibFunc::stpncpy: + case LibFunc::strncpy: { + if (NumParams != 3 || FT->getParamType(0) != FT->getParamType(1) || + FT->getParamType(0) != PCharTy || !FT->getParamType(2)->isIntegerTy()) + return false; + break; + } + case LibFunc::strcpy_chk: + case LibFunc::stpcpy_chk: + --NumParams; // fallthrough + case LibFunc::stpcpy: + case LibFunc::strcpy: { + if (NumParams != 2 || FT->getParamType(0) != FT->getParamType(1) || + FT->getParamType(0) != PCharTy) + return false; + break; + } + case LibFunc::memmove_chk: + case LibFunc::memcpy_chk: + --NumParams; // fallthrough + case LibFunc::memmove: + case LibFunc::memcpy: { + if (NumParams != 3 || !FT->getParamType(0)->isPointerTy() || + !FT->getParamType(1)->isPointerTy() || FT->getParamType(2) != SizeTTy) + return false; + break; } -}; - -struct MemMoveChkOpt : public InstFortifiedLibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - this->CI = CI; - FunctionType *FT = Callee->getFunctionType(); - LLVMContext &Context = CI->getParent()->getContext(); + case LibFunc::memset_chk: + --NumParams; // fallthrough + case LibFunc::memset: { + if (NumParams != 3 || !FT->getParamType(0)->isPointerTy() || + !FT->getParamType(1)->isIntegerTy() || FT->getParamType(2) != SizeTTy) + return false; + break; + } + } + // If this is a fortified libcall, the last parameter is a size_t. + if (NumParams == FT->getNumParams() - 1) + return FT->getParamType(FT->getNumParams() - 1) == SizeTTy; + return true; +} - // Check if this has the right signature. - if (FT->getNumParams() != 4 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy() || - FT->getParamType(2) != DL->getIntPtrType(Context) || - FT->getParamType(3) != DL->getIntPtrType(Context)) - return nullptr; +//===----------------------------------------------------------------------===// +// String and Memory Library Call Optimizations +//===----------------------------------------------------------------------===// - if (isFoldable(3, 2, false)) { - B.CreateMemMove(CI->getArgOperand(0), CI->getArgOperand(1), - CI->getArgOperand(2), 1); - return CI->getArgOperand(0); - } +Value *LibCallSimplifier::optimizeStrCat(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // Verify the "strcat" function prototype. + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 2|| + FT->getReturnType() != B.getInt8PtrTy() || + FT->getParamType(0) != FT->getReturnType() || + FT->getParamType(1) != FT->getReturnType()) return nullptr; - } -}; -struct MemSetChkOpt : public InstFortifiedLibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - this->CI = CI; - FunctionType *FT = Callee->getFunctionType(); - LLVMContext &Context = CI->getParent()->getContext(); - - // Check if this has the right signature. - if (FT->getNumParams() != 4 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isIntegerTy() || - FT->getParamType(2) != DL->getIntPtrType(Context) || - FT->getParamType(3) != DL->getIntPtrType(Context)) - return nullptr; + // Extract some information from the instruction + Value *Dst = CI->getArgOperand(0); + Value *Src = CI->getArgOperand(1); - if (isFoldable(3, 2, false)) { - Value *Val = B.CreateIntCast(CI->getArgOperand(1), B.getInt8Ty(), - false); - B.CreateMemSet(CI->getArgOperand(0), Val, CI->getArgOperand(2), 1); - return CI->getArgOperand(0); - } + // See if we can get the length of the input string. + uint64_t Len = GetStringLength(Src); + if (Len == 0) return nullptr; - } -}; - -struct StrCpyChkOpt : public InstFortifiedLibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - this->CI = CI; - StringRef Name = Callee->getName(); - FunctionType *FT = Callee->getFunctionType(); - LLVMContext &Context = CI->getParent()->getContext(); + --Len; // Unbias length. - // Check if this has the right signature. - if (FT->getNumParams() != 3 || - FT->getReturnType() != FT->getParamType(0) || - FT->getParamType(0) != FT->getParamType(1) || - FT->getParamType(0) != Type::getInt8PtrTy(Context) || - FT->getParamType(2) != DL->getIntPtrType(Context)) - return nullptr; + // Handle the simple, do-nothing case: strcat(x, "") -> x + if (Len == 0) + return Dst; - Value *Dst = CI->getArgOperand(0), *Src = CI->getArgOperand(1); - if (Dst == Src) // __strcpy_chk(x,x) -> x - return Src; - - // If a) we don't have any length information, or b) we know this will - // fit then just lower to a plain strcpy. Otherwise we'll keep our - // strcpy_chk call which may fail at runtime if the size is too long. - // TODO: It might be nice to get a maximum length out of the possible - // string lengths for varying. - if (isFoldable(2, 1, true)) { - Value *Ret = EmitStrCpy(Dst, Src, B, DL, TLI, Name.substr(2, 6)); - return Ret; - } else { - // Maybe we can stil fold __strcpy_chk to __memcpy_chk. - uint64_t Len = GetStringLength(Src); - if (Len == 0) return nullptr; - - // This optimization require DataLayout. - if (!DL) return nullptr; - - Value *Ret = - EmitMemCpyChk(Dst, Src, - ConstantInt::get(DL->getIntPtrType(Context), Len), - CI->getArgOperand(2), B, DL, TLI); - return Ret; - } + // These optimizations require DataLayout. + if (!DL) return nullptr; - } -}; -struct StpCpyChkOpt : public InstFortifiedLibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - this->CI = CI; - StringRef Name = Callee->getName(); - FunctionType *FT = Callee->getFunctionType(); - LLVMContext &Context = CI->getParent()->getContext(); + return emitStrLenMemCpy(Src, Dst, Len, B); +} - // Check if this has the right signature. - if (FT->getNumParams() != 3 || - FT->getReturnType() != FT->getParamType(0) || - FT->getParamType(0) != FT->getParamType(1) || - FT->getParamType(0) != Type::getInt8PtrTy(Context) || - FT->getParamType(2) != DL->getIntPtrType(FT->getParamType(0))) - return nullptr; +Value *LibCallSimplifier::emitStrLenMemCpy(Value *Src, Value *Dst, uint64_t Len, + IRBuilder<> &B) { + // We need to find the end of the destination string. That's where the + // memory is to be moved to. We just generate a call to strlen. + Value *DstLen = EmitStrLen(Dst, B, DL, TLI); + if (!DstLen) + return nullptr; - Value *Dst = CI->getArgOperand(0), *Src = CI->getArgOperand(1); - if (Dst == Src) { // stpcpy(x,x) -> x+strlen(x) - Value *StrLen = EmitStrLen(Src, B, DL, TLI); - return StrLen ? B.CreateInBoundsGEP(Dst, StrLen) : nullptr; - } + // Now that we have the destination's length, we must index into the + // destination's pointer to get the actual memcpy destination (end of + // the string .. we're concatenating). + Value *CpyDst = B.CreateGEP(Dst, DstLen, "endptr"); + + // We have enough information to now generate the memcpy call to do the + // concatenation for us. Make a memcpy to copy the nul byte with align = 1. + B.CreateMemCpy( + CpyDst, Src, + ConstantInt::get(DL->getIntPtrType(Src->getContext()), Len + 1), 1); + return Dst; +} - // If a) we don't have any length information, or b) we know this will - // fit then just lower to a plain stpcpy. Otherwise we'll keep our - // stpcpy_chk call which may fail at runtime if the size is too long. - // TODO: It might be nice to get a maximum length out of the possible - // string lengths for varying. - if (isFoldable(2, 1, true)) { - Value *Ret = EmitStrCpy(Dst, Src, B, DL, TLI, Name.substr(2, 6)); - return Ret; - } else { - // Maybe we can stil fold __stpcpy_chk to __memcpy_chk. - uint64_t Len = GetStringLength(Src); - if (Len == 0) return nullptr; - - // This optimization require DataLayout. - if (!DL) return nullptr; - - Type *PT = FT->getParamType(0); - Value *LenV = ConstantInt::get(DL->getIntPtrType(PT), Len); - Value *DstEnd = B.CreateGEP(Dst, - ConstantInt::get(DL->getIntPtrType(PT), - Len - 1)); - if (!EmitMemCpyChk(Dst, Src, LenV, CI->getArgOperand(2), B, DL, TLI)) - return nullptr; - return DstEnd; - } +Value *LibCallSimplifier::optimizeStrNCat(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // Verify the "strncat" function prototype. + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 3 || FT->getReturnType() != B.getInt8PtrTy() || + FT->getParamType(0) != FT->getReturnType() || + FT->getParamType(1) != FT->getReturnType() || + !FT->getParamType(2)->isIntegerTy()) return nullptr; - } -}; - -struct StrNCpyChkOpt : public InstFortifiedLibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - this->CI = CI; - StringRef Name = Callee->getName(); - FunctionType *FT = Callee->getFunctionType(); - LLVMContext &Context = CI->getParent()->getContext(); - // Check if this has the right signature. - if (FT->getNumParams() != 4 || FT->getReturnType() != FT->getParamType(0) || - FT->getParamType(0) != FT->getParamType(1) || - FT->getParamType(0) != Type::getInt8PtrTy(Context) || - !FT->getParamType(2)->isIntegerTy() || - FT->getParamType(3) != DL->getIntPtrType(Context)) - return nullptr; + // Extract some information from the instruction + Value *Dst = CI->getArgOperand(0); + Value *Src = CI->getArgOperand(1); + uint64_t Len; - if (isFoldable(3, 2, false)) { - Value *Ret = EmitStrNCpy(CI->getArgOperand(0), CI->getArgOperand(1), - CI->getArgOperand(2), B, DL, TLI, - Name.substr(2, 7)); - return Ret; - } + // We don't do anything if length is not constant + if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(CI->getArgOperand(2))) + Len = LengthArg->getZExtValue(); + else return nullptr; - } -}; -//===----------------------------------------------------------------------===// -// String and Memory Library Call Optimizations -//===----------------------------------------------------------------------===// + // See if we can get the length of the input string. + uint64_t SrcLen = GetStringLength(Src); + if (SrcLen == 0) + return nullptr; + --SrcLen; // Unbias length. -struct StrCatOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // Verify the "strcat" function prototype. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || - FT->getReturnType() != B.getInt8PtrTy() || - FT->getParamType(0) != FT->getReturnType() || - FT->getParamType(1) != FT->getReturnType()) - return nullptr; + // Handle the simple, do-nothing cases: + // strncat(x, "", c) -> x + // strncat(x, c, 0) -> x + if (SrcLen == 0 || Len == 0) + return Dst; - // Extract some information from the instruction - Value *Dst = CI->getArgOperand(0); - Value *Src = CI->getArgOperand(1); + // These optimizations require DataLayout. + if (!DL) + return nullptr; - // See if we can get the length of the input string. - uint64_t Len = GetStringLength(Src); - if (Len == 0) return nullptr; - --Len; // Unbias length. + // We don't optimize this case + if (Len < SrcLen) + return nullptr; - // Handle the simple, do-nothing case: strcat(x, "") -> x - if (Len == 0) - return Dst; + // strncat(x, s, c) -> strcat(x, s) + // s is constant so the strcat can be optimized further + return emitStrLenMemCpy(Src, Dst, SrcLen, B); +} - // These optimizations require DataLayout. - if (!DL) return nullptr; +Value *LibCallSimplifier::optimizeStrChr(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // Verify the "strchr" function prototype. + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 2 || FT->getReturnType() != B.getInt8PtrTy() || + FT->getParamType(0) != FT->getReturnType() || + !FT->getParamType(1)->isIntegerTy(32)) + return nullptr; - return emitStrLenMemCpy(Src, Dst, Len, B); - } + Value *SrcStr = CI->getArgOperand(0); - Value *emitStrLenMemCpy(Value *Src, Value *Dst, uint64_t Len, - IRBuilder<> &B) { - // We need to find the end of the destination string. That's where the - // memory is to be moved to. We just generate a call to strlen. - Value *DstLen = EmitStrLen(Dst, B, DL, TLI); - if (!DstLen) + // If the second operand is non-constant, see if we can compute the length + // of the input string and turn this into memchr. + ConstantInt *CharC = dyn_cast<ConstantInt>(CI->getArgOperand(1)); + if (!CharC) { + // These optimizations require DataLayout. + if (!DL) return nullptr; - // Now that we have the destination's length, we must index into the - // destination's pointer to get the actual memcpy destination (end of - // the string .. we're concatenating). - Value *CpyDst = B.CreateGEP(Dst, DstLen, "endptr"); + uint64_t Len = GetStringLength(SrcStr); + if (Len == 0 || !FT->getParamType(1)->isIntegerTy(32)) // memchr needs i32. + return nullptr; - // We have enough information to now generate the memcpy call to do the - // concatenation for us. Make a memcpy to copy the nul byte with align = 1. - B.CreateMemCpy(CpyDst, Src, - ConstantInt::get(DL->getIntPtrType(*Context), Len + 1), 1); - return Dst; + return EmitMemChr( + SrcStr, CI->getArgOperand(1), // include nul. + ConstantInt::get(DL->getIntPtrType(CI->getContext()), Len), B, DL, TLI); } -}; - -struct StrNCatOpt : public StrCatOpt { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // Verify the "strncat" function prototype. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 3 || - FT->getReturnType() != B.getInt8PtrTy() || - FT->getParamType(0) != FT->getReturnType() || - FT->getParamType(1) != FT->getReturnType() || - !FT->getParamType(2)->isIntegerTy()) - return nullptr; - // Extract some information from the instruction - Value *Dst = CI->getArgOperand(0); - Value *Src = CI->getArgOperand(1); - uint64_t Len; + // Otherwise, the character is a constant, see if the first argument is + // a string literal. If so, we can constant fold. + StringRef Str; + if (!getConstantStringInfo(SrcStr, Str)) { + if (DL && CharC->isZero()) // strchr(p, 0) -> p + strlen(p) + return B.CreateGEP(SrcStr, EmitStrLen(SrcStr, B, DL, TLI), "strchr"); + return nullptr; + } - // We don't do anything if length is not constant - if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(CI->getArgOperand(2))) - Len = LengthArg->getZExtValue(); - else - return nullptr; + // Compute the offset, make sure to handle the case when we're searching for + // zero (a weird way to spell strlen). + size_t I = (0xFF & CharC->getSExtValue()) == 0 + ? Str.size() + : Str.find(CharC->getSExtValue()); + if (I == StringRef::npos) // Didn't find the char. strchr returns null. + return Constant::getNullValue(CI->getType()); - // See if we can get the length of the input string. - uint64_t SrcLen = GetStringLength(Src); - if (SrcLen == 0) return nullptr; - --SrcLen; // Unbias length. + // strchr(s+n,c) -> gep(s+n+i,c) + return B.CreateGEP(SrcStr, B.getInt64(I), "strchr"); +} - // Handle the simple, do-nothing cases: - // strncat(x, "", c) -> x - // strncat(x, c, 0) -> x - if (SrcLen == 0 || Len == 0) return Dst; +Value *LibCallSimplifier::optimizeStrRChr(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // Verify the "strrchr" function prototype. + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 2 || FT->getReturnType() != B.getInt8PtrTy() || + FT->getParamType(0) != FT->getReturnType() || + !FT->getParamType(1)->isIntegerTy(32)) + return nullptr; - // These optimizations require DataLayout. - if (!DL) return nullptr; - - // We don't optimize this case - if (Len < SrcLen) return nullptr; - - // strncat(x, s, c) -> strcat(x, s) - // s is constant so the strcat can be optimized further - return emitStrLenMemCpy(Src, Dst, SrcLen, B); - } -}; - -struct StrChrOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // Verify the "strchr" function prototype. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || - FT->getReturnType() != B.getInt8PtrTy() || - FT->getParamType(0) != FT->getReturnType() || - !FT->getParamType(1)->isIntegerTy(32)) - return nullptr; + Value *SrcStr = CI->getArgOperand(0); + ConstantInt *CharC = dyn_cast<ConstantInt>(CI->getArgOperand(1)); - Value *SrcStr = CI->getArgOperand(0); + // Cannot fold anything if we're not looking for a constant. + if (!CharC) + return nullptr; - // If the second operand is non-constant, see if we can compute the length - // of the input string and turn this into memchr. - ConstantInt *CharC = dyn_cast<ConstantInt>(CI->getArgOperand(1)); - if (!CharC) { - // These optimizations require DataLayout. - if (!DL) return nullptr; + StringRef Str; + if (!getConstantStringInfo(SrcStr, Str)) { + // strrchr(s, 0) -> strchr(s, 0) + if (DL && CharC->isZero()) + return EmitStrChr(SrcStr, '\0', B, DL, TLI); + return nullptr; + } - uint64_t Len = GetStringLength(SrcStr); - if (Len == 0 || !FT->getParamType(1)->isIntegerTy(32))// memchr needs i32. - return nullptr; + // Compute the offset. + size_t I = (0xFF & CharC->getSExtValue()) == 0 + ? Str.size() + : Str.rfind(CharC->getSExtValue()); + if (I == StringRef::npos) // Didn't find the char. Return null. + return Constant::getNullValue(CI->getType()); - return EmitMemChr(SrcStr, CI->getArgOperand(1), // include nul. - ConstantInt::get(DL->getIntPtrType(*Context), Len), - B, DL, TLI); - } + // strrchr(s+n,c) -> gep(s+n+i,c) + return B.CreateGEP(SrcStr, B.getInt64(I), "strrchr"); +} - // Otherwise, the character is a constant, see if the first argument is - // a string literal. If so, we can constant fold. - StringRef Str; - if (!getConstantStringInfo(SrcStr, Str)) { - if (DL && CharC->isZero()) // strchr(p, 0) -> p + strlen(p) - return B.CreateGEP(SrcStr, EmitStrLen(SrcStr, B, DL, TLI), "strchr"); - return nullptr; - } +Value *LibCallSimplifier::optimizeStrCmp(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // Verify the "strcmp" function prototype. + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 2 || !FT->getReturnType()->isIntegerTy(32) || + FT->getParamType(0) != FT->getParamType(1) || + FT->getParamType(0) != B.getInt8PtrTy()) + return nullptr; - // Compute the offset, make sure to handle the case when we're searching for - // zero (a weird way to spell strlen). - size_t I = (0xFF & CharC->getSExtValue()) == 0 ? - Str.size() : Str.find(CharC->getSExtValue()); - if (I == StringRef::npos) // Didn't find the char. strchr returns null. - return Constant::getNullValue(CI->getType()); + Value *Str1P = CI->getArgOperand(0), *Str2P = CI->getArgOperand(1); + if (Str1P == Str2P) // strcmp(x,x) -> 0 + return ConstantInt::get(CI->getType(), 0); - // strchr(s+n,c) -> gep(s+n+i,c) - return B.CreateGEP(SrcStr, B.getInt64(I), "strchr"); - } -}; + StringRef Str1, Str2; + bool HasStr1 = getConstantStringInfo(Str1P, Str1); + bool HasStr2 = getConstantStringInfo(Str2P, Str2); -struct StrRChrOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // Verify the "strrchr" function prototype. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || - FT->getReturnType() != B.getInt8PtrTy() || - FT->getParamType(0) != FT->getReturnType() || - !FT->getParamType(1)->isIntegerTy(32)) - return nullptr; + // strcmp(x, y) -> cnst (if both x and y are constant strings) + if (HasStr1 && HasStr2) + return ConstantInt::get(CI->getType(), Str1.compare(Str2)); - Value *SrcStr = CI->getArgOperand(0); - ConstantInt *CharC = dyn_cast<ConstantInt>(CI->getArgOperand(1)); + if (HasStr1 && Str1.empty()) // strcmp("", x) -> -*x + return B.CreateNeg( + B.CreateZExt(B.CreateLoad(Str2P, "strcmpload"), CI->getType())); - // Cannot fold anything if we're not looking for a constant. - if (!CharC) - return nullptr; + if (HasStr2 && Str2.empty()) // strcmp(x,"") -> *x + return B.CreateZExt(B.CreateLoad(Str1P, "strcmpload"), CI->getType()); - StringRef Str; - if (!getConstantStringInfo(SrcStr, Str)) { - // strrchr(s, 0) -> strchr(s, 0) - if (DL && CharC->isZero()) - return EmitStrChr(SrcStr, '\0', B, DL, TLI); + // strcmp(P, "x") -> memcmp(P, "x", 2) + uint64_t Len1 = GetStringLength(Str1P); + uint64_t Len2 = GetStringLength(Str2P); + if (Len1 && Len2) { + // These optimizations require DataLayout. + if (!DL) return nullptr; - } - - // Compute the offset. - size_t I = (0xFF & CharC->getSExtValue()) == 0 ? - Str.size() : Str.rfind(CharC->getSExtValue()); - if (I == StringRef::npos) // Didn't find the char. Return null. - return Constant::getNullValue(CI->getType()); - // strrchr(s+n,c) -> gep(s+n+i,c) - return B.CreateGEP(SrcStr, B.getInt64(I), "strrchr"); + return EmitMemCmp(Str1P, Str2P, + ConstantInt::get(DL->getIntPtrType(CI->getContext()), + std::min(Len1, Len2)), + B, DL, TLI); } -}; - -struct StrCmpOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // Verify the "strcmp" function prototype. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || - !FT->getReturnType()->isIntegerTy(32) || - FT->getParamType(0) != FT->getParamType(1) || - FT->getParamType(0) != B.getInt8PtrTy()) - return nullptr; - Value *Str1P = CI->getArgOperand(0), *Str2P = CI->getArgOperand(1); - if (Str1P == Str2P) // strcmp(x,x) -> 0 - return ConstantInt::get(CI->getType(), 0); + return nullptr; +} - StringRef Str1, Str2; - bool HasStr1 = getConstantStringInfo(Str1P, Str1); - bool HasStr2 = getConstantStringInfo(Str2P, Str2); +Value *LibCallSimplifier::optimizeStrNCmp(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // Verify the "strncmp" function prototype. + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 3 || !FT->getReturnType()->isIntegerTy(32) || + FT->getParamType(0) != FT->getParamType(1) || + FT->getParamType(0) != B.getInt8PtrTy() || + !FT->getParamType(2)->isIntegerTy()) + return nullptr; - // strcmp(x, y) -> cnst (if both x and y are constant strings) - if (HasStr1 && HasStr2) - return ConstantInt::get(CI->getType(), Str1.compare(Str2)); + Value *Str1P = CI->getArgOperand(0), *Str2P = CI->getArgOperand(1); + if (Str1P == Str2P) // strncmp(x,x,n) -> 0 + return ConstantInt::get(CI->getType(), 0); - if (HasStr1 && Str1.empty()) // strcmp("", x) -> -*x - return B.CreateNeg(B.CreateZExt(B.CreateLoad(Str2P, "strcmpload"), - CI->getType())); + // Get the length argument if it is constant. + uint64_t Length; + if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(CI->getArgOperand(2))) + Length = LengthArg->getZExtValue(); + else + return nullptr; - if (HasStr2 && Str2.empty()) // strcmp(x,"") -> *x - return B.CreateZExt(B.CreateLoad(Str1P, "strcmpload"), CI->getType()); + if (Length == 0) // strncmp(x,y,0) -> 0 + return ConstantInt::get(CI->getType(), 0); - // strcmp(P, "x") -> memcmp(P, "x", 2) - uint64_t Len1 = GetStringLength(Str1P); - uint64_t Len2 = GetStringLength(Str2P); - if (Len1 && Len2) { - // These optimizations require DataLayout. - if (!DL) return nullptr; + if (DL && Length == 1) // strncmp(x,y,1) -> memcmp(x,y,1) + return EmitMemCmp(Str1P, Str2P, CI->getArgOperand(2), B, DL, TLI); - return EmitMemCmp(Str1P, Str2P, - ConstantInt::get(DL->getIntPtrType(*Context), - std::min(Len1, Len2)), B, DL, TLI); - } + StringRef Str1, Str2; + bool HasStr1 = getConstantStringInfo(Str1P, Str1); + bool HasStr2 = getConstantStringInfo(Str2P, Str2); - return nullptr; + // strncmp(x, y) -> cnst (if both x and y are constant strings) + if (HasStr1 && HasStr2) { + StringRef SubStr1 = Str1.substr(0, Length); + StringRef SubStr2 = Str2.substr(0, Length); + return ConstantInt::get(CI->getType(), SubStr1.compare(SubStr2)); } -}; -struct StrNCmpOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // Verify the "strncmp" function prototype. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 3 || - !FT->getReturnType()->isIntegerTy(32) || - FT->getParamType(0) != FT->getParamType(1) || - FT->getParamType(0) != B.getInt8PtrTy() || - !FT->getParamType(2)->isIntegerTy()) - return nullptr; + if (HasStr1 && Str1.empty()) // strncmp("", x, n) -> -*x + return B.CreateNeg( + B.CreateZExt(B.CreateLoad(Str2P, "strcmpload"), CI->getType())); - Value *Str1P = CI->getArgOperand(0), *Str2P = CI->getArgOperand(1); - if (Str1P == Str2P) // strncmp(x,x,n) -> 0 - return ConstantInt::get(CI->getType(), 0); - - // Get the length argument if it is constant. - uint64_t Length; - if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(CI->getArgOperand(2))) - Length = LengthArg->getZExtValue(); - else - return nullptr; + if (HasStr2 && Str2.empty()) // strncmp(x, "", n) -> *x + return B.CreateZExt(B.CreateLoad(Str1P, "strcmpload"), CI->getType()); - if (Length == 0) // strncmp(x,y,0) -> 0 - return ConstantInt::get(CI->getType(), 0); - - if (DL && Length == 1) // strncmp(x,y,1) -> memcmp(x,y,1) - return EmitMemCmp(Str1P, Str2P, CI->getArgOperand(2), B, DL, TLI); + return nullptr; +} - StringRef Str1, Str2; - bool HasStr1 = getConstantStringInfo(Str1P, Str1); - bool HasStr2 = getConstantStringInfo(Str2P, Str2); +Value *LibCallSimplifier::optimizeStrCpy(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); - // strncmp(x, y) -> cnst (if both x and y are constant strings) - if (HasStr1 && HasStr2) { - StringRef SubStr1 = Str1.substr(0, Length); - StringRef SubStr2 = Str2.substr(0, Length); - return ConstantInt::get(CI->getType(), SubStr1.compare(SubStr2)); - } + if (!checkStringCopyLibFuncSignature(Callee, LibFunc::strcpy, DL)) + return nullptr; - if (HasStr1 && Str1.empty()) // strncmp("", x, n) -> -*x - return B.CreateNeg(B.CreateZExt(B.CreateLoad(Str2P, "strcmpload"), - CI->getType())); + Value *Dst = CI->getArgOperand(0), *Src = CI->getArgOperand(1); + if (Dst == Src) // strcpy(x,x) -> x + return Src; - if (HasStr2 && Str2.empty()) // strncmp(x, "", n) -> *x - return B.CreateZExt(B.CreateLoad(Str1P, "strcmpload"), CI->getType()); + // These optimizations require DataLayout. + if (!DL) + return nullptr; + // See if we can get the length of the input string. + uint64_t Len = GetStringLength(Src); + if (Len == 0) return nullptr; - } -}; -struct StrCpyOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // Verify the "strcpy" function prototype. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || - FT->getReturnType() != FT->getParamType(0) || - FT->getParamType(0) != FT->getParamType(1) || - FT->getParamType(0) != B.getInt8PtrTy()) - return nullptr; + // We have enough information to now generate the memcpy call to do the + // copy for us. Make a memcpy to copy the nul byte with align = 1. + B.CreateMemCpy(Dst, Src, + ConstantInt::get(DL->getIntPtrType(CI->getContext()), Len), 1); + return Dst; +} - Value *Dst = CI->getArgOperand(0), *Src = CI->getArgOperand(1); - if (Dst == Src) // strcpy(x,x) -> x - return Src; +Value *LibCallSimplifier::optimizeStpCpy(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // Verify the "stpcpy" function prototype. + FunctionType *FT = Callee->getFunctionType(); - // These optimizations require DataLayout. - if (!DL) return nullptr; + if (!checkStringCopyLibFuncSignature(Callee, LibFunc::stpcpy, DL)) + return nullptr; - // See if we can get the length of the input string. - uint64_t Len = GetStringLength(Src); - if (Len == 0) return nullptr; + // These optimizations require DataLayout. + if (!DL) + return nullptr; - // We have enough information to now generate the memcpy call to do the - // copy for us. Make a memcpy to copy the nul byte with align = 1. - B.CreateMemCpy(Dst, Src, - ConstantInt::get(DL->getIntPtrType(*Context), Len), 1); - return Dst; + Value *Dst = CI->getArgOperand(0), *Src = CI->getArgOperand(1); + if (Dst == Src) { // stpcpy(x,x) -> x+strlen(x) + Value *StrLen = EmitStrLen(Src, B, DL, TLI); + return StrLen ? B.CreateInBoundsGEP(Dst, StrLen) : nullptr; } -}; - -struct StpCpyOpt: public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // Verify the "stpcpy" function prototype. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || - FT->getReturnType() != FT->getParamType(0) || - FT->getParamType(0) != FT->getParamType(1) || - FT->getParamType(0) != B.getInt8PtrTy()) - return nullptr; - - // These optimizations require DataLayout. - if (!DL) return nullptr; - - Value *Dst = CI->getArgOperand(0), *Src = CI->getArgOperand(1); - if (Dst == Src) { // stpcpy(x,x) -> x+strlen(x) - Value *StrLen = EmitStrLen(Src, B, DL, TLI); - return StrLen ? B.CreateInBoundsGEP(Dst, StrLen) : nullptr; - } - - // See if we can get the length of the input string. - uint64_t Len = GetStringLength(Src); - if (Len == 0) return nullptr; - - Type *PT = FT->getParamType(0); - Value *LenV = ConstantInt::get(DL->getIntPtrType(PT), Len); - Value *DstEnd = B.CreateGEP(Dst, - ConstantInt::get(DL->getIntPtrType(PT), - Len - 1)); - - // We have enough information to now generate the memcpy call to do the - // copy for us. Make a memcpy to copy the nul byte with align = 1. - B.CreateMemCpy(Dst, Src, LenV, 1); - return DstEnd; - } -}; - -struct StrNCpyOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 3 || FT->getReturnType() != FT->getParamType(0) || - FT->getParamType(0) != FT->getParamType(1) || - FT->getParamType(0) != B.getInt8PtrTy() || - !FT->getParamType(2)->isIntegerTy()) - return nullptr; - - Value *Dst = CI->getArgOperand(0); - Value *Src = CI->getArgOperand(1); - Value *LenOp = CI->getArgOperand(2); - // See if we can get the length of the input string. - uint64_t SrcLen = GetStringLength(Src); - if (SrcLen == 0) return nullptr; - --SrcLen; + // See if we can get the length of the input string. + uint64_t Len = GetStringLength(Src); + if (Len == 0) + return nullptr; - if (SrcLen == 0) { - // strncpy(x, "", y) -> memset(x, '\0', y, 1) - B.CreateMemSet(Dst, B.getInt8('\0'), LenOp, 1); - return Dst; - } + Type *PT = FT->getParamType(0); + Value *LenV = ConstantInt::get(DL->getIntPtrType(PT), Len); + Value *DstEnd = + B.CreateGEP(Dst, ConstantInt::get(DL->getIntPtrType(PT), Len - 1)); - uint64_t Len; - if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(LenOp)) - Len = LengthArg->getZExtValue(); - else - return nullptr; + // We have enough information to now generate the memcpy call to do the + // copy for us. Make a memcpy to copy the nul byte with align = 1. + B.CreateMemCpy(Dst, Src, LenV, 1); + return DstEnd; +} - if (Len == 0) return Dst; // strncpy(x, y, 0) -> x +Value *LibCallSimplifier::optimizeStrNCpy(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); - // These optimizations require DataLayout. - if (!DL) return nullptr; + if (!checkStringCopyLibFuncSignature(Callee, LibFunc::strncpy, DL)) + return nullptr; - // Let strncpy handle the zero padding - if (Len > SrcLen+1) return nullptr; + Value *Dst = CI->getArgOperand(0); + Value *Src = CI->getArgOperand(1); + Value *LenOp = CI->getArgOperand(2); - Type *PT = FT->getParamType(0); - // strncpy(x, s, c) -> memcpy(x, s, c, 1) [s and c are constant] - B.CreateMemCpy(Dst, Src, - ConstantInt::get(DL->getIntPtrType(PT), Len), 1); + // See if we can get the length of the input string. + uint64_t SrcLen = GetStringLength(Src); + if (SrcLen == 0) + return nullptr; + --SrcLen; + if (SrcLen == 0) { + // strncpy(x, "", y) -> memset(x, '\0', y, 1) + B.CreateMemSet(Dst, B.getInt8('\0'), LenOp, 1); return Dst; } -}; - -struct StrLenOpt : public LibCallOptimization { - bool ignoreCallingConv() override { return true; } - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 1 || - FT->getParamType(0) != B.getInt8PtrTy() || - !FT->getReturnType()->isIntegerTy()) - return nullptr; - Value *Src = CI->getArgOperand(0); - - // Constant folding: strlen("xyz") -> 3 - if (uint64_t Len = GetStringLength(Src)) - return ConstantInt::get(CI->getType(), Len-1); - - // strlen(x?"foo":"bars") --> x ? 3 : 4 - if (SelectInst *SI = dyn_cast<SelectInst>(Src)) { - uint64_t LenTrue = GetStringLength(SI->getTrueValue()); - uint64_t LenFalse = GetStringLength(SI->getFalseValue()); - if (LenTrue && LenFalse) { - emitOptimizationRemark(*Context, "simplify-libcalls", *Caller, - SI->getDebugLoc(), - "folded strlen(select) to select of constants"); - return B.CreateSelect(SI->getCondition(), - ConstantInt::get(CI->getType(), LenTrue-1), - ConstantInt::get(CI->getType(), LenFalse-1)); - } - } + uint64_t Len; + if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(LenOp)) + Len = LengthArg->getZExtValue(); + else + return nullptr; - // strlen(x) != 0 --> *x != 0 - // strlen(x) == 0 --> *x == 0 - if (isOnlyUsedInZeroEqualityComparison(CI)) - return B.CreateZExt(B.CreateLoad(Src, "strlenfirst"), CI->getType()); + if (Len == 0) + return Dst; // strncpy(x, y, 0) -> x + // These optimizations require DataLayout. + if (!DL) return nullptr; - } -}; -struct StrPBrkOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || - FT->getParamType(0) != B.getInt8PtrTy() || - FT->getParamType(1) != FT->getParamType(0) || - FT->getReturnType() != FT->getParamType(0)) - return nullptr; + // Let strncpy handle the zero padding + if (Len > SrcLen + 1) + return nullptr; - StringRef S1, S2; - bool HasS1 = getConstantStringInfo(CI->getArgOperand(0), S1); - bool HasS2 = getConstantStringInfo(CI->getArgOperand(1), S2); + Type *PT = FT->getParamType(0); + // strncpy(x, s, c) -> memcpy(x, s, c, 1) [s and c are constant] + B.CreateMemCpy(Dst, Src, ConstantInt::get(DL->getIntPtrType(PT), Len), 1); - // strpbrk(s, "") -> NULL - // strpbrk("", s) -> NULL - if ((HasS1 && S1.empty()) || (HasS2 && S2.empty())) - return Constant::getNullValue(CI->getType()); + return Dst; +} - // Constant folding. - if (HasS1 && HasS2) { - size_t I = S1.find_first_of(S2); - if (I == StringRef::npos) // No match. - return Constant::getNullValue(CI->getType()); +Value *LibCallSimplifier::optimizeStrLen(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 1 || FT->getParamType(0) != B.getInt8PtrTy() || + !FT->getReturnType()->isIntegerTy()) + return nullptr; - return B.CreateGEP(CI->getArgOperand(0), B.getInt64(I), "strpbrk"); + Value *Src = CI->getArgOperand(0); + + // Constant folding: strlen("xyz") -> 3 + if (uint64_t Len = GetStringLength(Src)) + return ConstantInt::get(CI->getType(), Len - 1); + + // strlen(x?"foo":"bars") --> x ? 3 : 4 + if (SelectInst *SI = dyn_cast<SelectInst>(Src)) { + uint64_t LenTrue = GetStringLength(SI->getTrueValue()); + uint64_t LenFalse = GetStringLength(SI->getFalseValue()); + if (LenTrue && LenFalse) { + Function *Caller = CI->getParent()->getParent(); + emitOptimizationRemark(CI->getContext(), "simplify-libcalls", *Caller, + SI->getDebugLoc(), + "folded strlen(select) to select of constants"); + return B.CreateSelect(SI->getCondition(), + ConstantInt::get(CI->getType(), LenTrue - 1), + ConstantInt::get(CI->getType(), LenFalse - 1)); } - - // strpbrk(s, "a") -> strchr(s, 'a') - if (DL && HasS2 && S2.size() == 1) - return EmitStrChr(CI->getArgOperand(0), S2[0], B, DL, TLI); - - return nullptr; } -}; -struct StrToOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - if ((FT->getNumParams() != 2 && FT->getNumParams() != 3) || - !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy()) - return nullptr; + // strlen(x) != 0 --> *x != 0 + // strlen(x) == 0 --> *x == 0 + if (isOnlyUsedInZeroEqualityComparison(CI)) + return B.CreateZExt(B.CreateLoad(Src, "strlenfirst"), CI->getType()); - Value *EndPtr = CI->getArgOperand(1); - if (isa<ConstantPointerNull>(EndPtr)) { - // With a null EndPtr, this function won't capture the main argument. - // It would be readonly too, except that it still may write to errno. - CI->addAttribute(1, Attribute::NoCapture); - } + return nullptr; +} +Value *LibCallSimplifier::optimizeStrPBrk(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 2 || FT->getParamType(0) != B.getInt8PtrTy() || + FT->getParamType(1) != FT->getParamType(0) || + FT->getReturnType() != FT->getParamType(0)) return nullptr; - } -}; -struct StrSpnOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || - FT->getParamType(0) != B.getInt8PtrTy() || - FT->getParamType(1) != FT->getParamType(0) || - !FT->getReturnType()->isIntegerTy()) - return nullptr; + StringRef S1, S2; + bool HasS1 = getConstantStringInfo(CI->getArgOperand(0), S1); + bool HasS2 = getConstantStringInfo(CI->getArgOperand(1), S2); - StringRef S1, S2; - bool HasS1 = getConstantStringInfo(CI->getArgOperand(0), S1); - bool HasS2 = getConstantStringInfo(CI->getArgOperand(1), S2); + // strpbrk(s, "") -> nullptr + // strpbrk("", s) -> nullptr + if ((HasS1 && S1.empty()) || (HasS2 && S2.empty())) + return Constant::getNullValue(CI->getType()); - // strspn(s, "") -> 0 - // strspn("", s) -> 0 - if ((HasS1 && S1.empty()) || (HasS2 && S2.empty())) + // Constant folding. + if (HasS1 && HasS2) { + size_t I = S1.find_first_of(S2); + if (I == StringRef::npos) // No match. return Constant::getNullValue(CI->getType()); - // Constant folding. - if (HasS1 && HasS2) { - size_t Pos = S1.find_first_not_of(S2); - if (Pos == StringRef::npos) Pos = S1.size(); - return ConstantInt::get(CI->getType(), Pos); - } - - return nullptr; + return B.CreateGEP(CI->getArgOperand(0), B.getInt64(I), "strpbrk"); } -}; -struct StrCSpnOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || - FT->getParamType(0) != B.getInt8PtrTy() || - FT->getParamType(1) != FT->getParamType(0) || - !FT->getReturnType()->isIntegerTy()) - return nullptr; + // strpbrk(s, "a") -> strchr(s, 'a') + if (DL && HasS2 && S2.size() == 1) + return EmitStrChr(CI->getArgOperand(0), S2[0], B, DL, TLI); - StringRef S1, S2; - bool HasS1 = getConstantStringInfo(CI->getArgOperand(0), S1); - bool HasS2 = getConstantStringInfo(CI->getArgOperand(1), S2); + return nullptr; +} - // strcspn("", s) -> 0 - if (HasS1 && S1.empty()) - return Constant::getNullValue(CI->getType()); +Value *LibCallSimplifier::optimizeStrTo(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + if ((FT->getNumParams() != 2 && FT->getNumParams() != 3) || + !FT->getParamType(0)->isPointerTy() || + !FT->getParamType(1)->isPointerTy()) + return nullptr; - // Constant folding. - if (HasS1 && HasS2) { - size_t Pos = S1.find_first_of(S2); - if (Pos == StringRef::npos) Pos = S1.size(); - return ConstantInt::get(CI->getType(), Pos); - } + Value *EndPtr = CI->getArgOperand(1); + if (isa<ConstantPointerNull>(EndPtr)) { + // With a null EndPtr, this function won't capture the main argument. + // It would be readonly too, except that it still may write to errno. + CI->addAttribute(1, Attribute::NoCapture); + } - // strcspn(s, "") -> strlen(s) - if (DL && HasS2 && S2.empty()) - return EmitStrLen(CI->getArgOperand(0), B, DL, TLI); + return nullptr; +} +Value *LibCallSimplifier::optimizeStrSpn(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 2 || FT->getParamType(0) != B.getInt8PtrTy() || + FT->getParamType(1) != FT->getParamType(0) || + !FT->getReturnType()->isIntegerTy()) return nullptr; + + StringRef S1, S2; + bool HasS1 = getConstantStringInfo(CI->getArgOperand(0), S1); + bool HasS2 = getConstantStringInfo(CI->getArgOperand(1), S2); + + // strspn(s, "") -> 0 + // strspn("", s) -> 0 + if ((HasS1 && S1.empty()) || (HasS2 && S2.empty())) + return Constant::getNullValue(CI->getType()); + + // Constant folding. + if (HasS1 && HasS2) { + size_t Pos = S1.find_first_not_of(S2); + if (Pos == StringRef::npos) + Pos = S1.size(); + return ConstantInt::get(CI->getType(), Pos); } -}; -struct StrStrOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || - !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy() || - !FT->getReturnType()->isPointerTy()) - return nullptr; + return nullptr; +} - // fold strstr(x, x) -> x. - if (CI->getArgOperand(0) == CI->getArgOperand(1)) - return B.CreateBitCast(CI->getArgOperand(0), CI->getType()); +Value *LibCallSimplifier::optimizeStrCSpn(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 2 || FT->getParamType(0) != B.getInt8PtrTy() || + FT->getParamType(1) != FT->getParamType(0) || + !FT->getReturnType()->isIntegerTy()) + return nullptr; - // fold strstr(a, b) == a -> strncmp(a, b, strlen(b)) == 0 - if (DL && isOnlyUsedInEqualityComparison(CI, CI->getArgOperand(0))) { - Value *StrLen = EmitStrLen(CI->getArgOperand(1), B, DL, TLI); - if (!StrLen) - return nullptr; - Value *StrNCmp = EmitStrNCmp(CI->getArgOperand(0), CI->getArgOperand(1), - StrLen, B, DL, TLI); - if (!StrNCmp) - return nullptr; - for (auto UI = CI->user_begin(), UE = CI->user_end(); UI != UE;) { - ICmpInst *Old = cast<ICmpInst>(*UI++); - Value *Cmp = B.CreateICmp(Old->getPredicate(), StrNCmp, - ConstantInt::getNullValue(StrNCmp->getType()), - "cmp"); - LCS->replaceAllUsesWith(Old, Cmp); - } - return CI; - } + StringRef S1, S2; + bool HasS1 = getConstantStringInfo(CI->getArgOperand(0), S1); + bool HasS2 = getConstantStringInfo(CI->getArgOperand(1), S2); - // See if either input string is a constant string. - StringRef SearchStr, ToFindStr; - bool HasStr1 = getConstantStringInfo(CI->getArgOperand(0), SearchStr); - bool HasStr2 = getConstantStringInfo(CI->getArgOperand(1), ToFindStr); + // strcspn("", s) -> 0 + if (HasS1 && S1.empty()) + return Constant::getNullValue(CI->getType()); - // fold strstr(x, "") -> x. - if (HasStr2 && ToFindStr.empty()) - return B.CreateBitCast(CI->getArgOperand(0), CI->getType()); + // Constant folding. + if (HasS1 && HasS2) { + size_t Pos = S1.find_first_of(S2); + if (Pos == StringRef::npos) + Pos = S1.size(); + return ConstantInt::get(CI->getType(), Pos); + } - // If both strings are known, constant fold it. - if (HasStr1 && HasStr2) { - size_t Offset = SearchStr.find(ToFindStr); + // strcspn(s, "") -> strlen(s) + if (DL && HasS2 && S2.empty()) + return EmitStrLen(CI->getArgOperand(0), B, DL, TLI); - if (Offset == StringRef::npos) // strstr("foo", "bar") -> null - return Constant::getNullValue(CI->getType()); + return nullptr; +} - // strstr("abcd", "bc") -> gep((char*)"abcd", 1) - Value *Result = CastToCStr(CI->getArgOperand(0), B); - Result = B.CreateConstInBoundsGEP1_64(Result, Offset, "strstr"); - return B.CreateBitCast(Result, CI->getType()); - } +Value *LibCallSimplifier::optimizeStrStr(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 2 || !FT->getParamType(0)->isPointerTy() || + !FT->getParamType(1)->isPointerTy() || + !FT->getReturnType()->isPointerTy()) + return nullptr; + + // fold strstr(x, x) -> x. + if (CI->getArgOperand(0) == CI->getArgOperand(1)) + return B.CreateBitCast(CI->getArgOperand(0), CI->getType()); - // fold strstr(x, "y") -> strchr(x, 'y'). - if (HasStr2 && ToFindStr.size() == 1) { - Value *StrChr= EmitStrChr(CI->getArgOperand(0), ToFindStr[0], B, DL, TLI); - return StrChr ? B.CreateBitCast(StrChr, CI->getType()) : nullptr; + // fold strstr(a, b) == a -> strncmp(a, b, strlen(b)) == 0 + if (DL && isOnlyUsedInEqualityComparison(CI, CI->getArgOperand(0))) { + Value *StrLen = EmitStrLen(CI->getArgOperand(1), B, DL, TLI); + if (!StrLen) + return nullptr; + Value *StrNCmp = EmitStrNCmp(CI->getArgOperand(0), CI->getArgOperand(1), + StrLen, B, DL, TLI); + if (!StrNCmp) + return nullptr; + for (auto UI = CI->user_begin(), UE = CI->user_end(); UI != UE;) { + ICmpInst *Old = cast<ICmpInst>(*UI++); + Value *Cmp = + B.CreateICmp(Old->getPredicate(), StrNCmp, + ConstantInt::getNullValue(StrNCmp->getType()), "cmp"); + replaceAllUsesWith(Old, Cmp); } - return nullptr; + return CI; } -}; -struct MemCmpOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 3 || !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy() || - !FT->getReturnType()->isIntegerTy(32)) - return nullptr; + // See if either input string is a constant string. + StringRef SearchStr, ToFindStr; + bool HasStr1 = getConstantStringInfo(CI->getArgOperand(0), SearchStr); + bool HasStr2 = getConstantStringInfo(CI->getArgOperand(1), ToFindStr); + + // fold strstr(x, "") -> x. + if (HasStr2 && ToFindStr.empty()) + return B.CreateBitCast(CI->getArgOperand(0), CI->getType()); - Value *LHS = CI->getArgOperand(0), *RHS = CI->getArgOperand(1); + // If both strings are known, constant fold it. + if (HasStr1 && HasStr2) { + size_t Offset = SearchStr.find(ToFindStr); - if (LHS == RHS) // memcmp(s,s,x) -> 0 + if (Offset == StringRef::npos) // strstr("foo", "bar") -> null return Constant::getNullValue(CI->getType()); - // Make sure we have a constant length. - ConstantInt *LenC = dyn_cast<ConstantInt>(CI->getArgOperand(2)); - if (!LenC) return nullptr; - uint64_t Len = LenC->getZExtValue(); + // strstr("abcd", "bc") -> gep((char*)"abcd", 1) + Value *Result = CastToCStr(CI->getArgOperand(0), B); + Result = B.CreateConstInBoundsGEP1_64(Result, Offset, "strstr"); + return B.CreateBitCast(Result, CI->getType()); + } - if (Len == 0) // memcmp(s1,s2,0) -> 0 - return Constant::getNullValue(CI->getType()); + // fold strstr(x, "y") -> strchr(x, 'y'). + if (HasStr2 && ToFindStr.size() == 1) { + Value *StrChr = EmitStrChr(CI->getArgOperand(0), ToFindStr[0], B, DL, TLI); + return StrChr ? B.CreateBitCast(StrChr, CI->getType()) : nullptr; + } + return nullptr; +} - // memcmp(S1,S2,1) -> *(unsigned char*)LHS - *(unsigned char*)RHS - if (Len == 1) { - Value *LHSV = B.CreateZExt(B.CreateLoad(CastToCStr(LHS, B), "lhsc"), - CI->getType(), "lhsv"); - Value *RHSV = B.CreateZExt(B.CreateLoad(CastToCStr(RHS, B), "rhsc"), - CI->getType(), "rhsv"); - return B.CreateSub(LHSV, RHSV, "chardiff"); - } +Value *LibCallSimplifier::optimizeMemCmp(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 3 || !FT->getParamType(0)->isPointerTy() || + !FT->getParamType(1)->isPointerTy() || + !FT->getReturnType()->isIntegerTy(32)) + return nullptr; - // Constant folding: memcmp(x, y, l) -> cnst (all arguments are constant) - StringRef LHSStr, RHSStr; - if (getConstantStringInfo(LHS, LHSStr) && - getConstantStringInfo(RHS, RHSStr)) { - // Make sure we're not reading out-of-bounds memory. - if (Len > LHSStr.size() || Len > RHSStr.size()) - return nullptr; - // Fold the memcmp and normalize the result. This way we get consistent - // results across multiple platforms. - uint64_t Ret = 0; - int Cmp = memcmp(LHSStr.data(), RHSStr.data(), Len); - if (Cmp < 0) - Ret = -1; - else if (Cmp > 0) - Ret = 1; - return ConstantInt::get(CI->getType(), Ret); - } + Value *LHS = CI->getArgOperand(0), *RHS = CI->getArgOperand(1); + if (LHS == RHS) // memcmp(s,s,x) -> 0 + return Constant::getNullValue(CI->getType()); + + // Make sure we have a constant length. + ConstantInt *LenC = dyn_cast<ConstantInt>(CI->getArgOperand(2)); + if (!LenC) return nullptr; + uint64_t Len = LenC->getZExtValue(); + + if (Len == 0) // memcmp(s1,s2,0) -> 0 + return Constant::getNullValue(CI->getType()); + + // memcmp(S1,S2,1) -> *(unsigned char*)LHS - *(unsigned char*)RHS + if (Len == 1) { + Value *LHSV = B.CreateZExt(B.CreateLoad(CastToCStr(LHS, B), "lhsc"), + CI->getType(), "lhsv"); + Value *RHSV = B.CreateZExt(B.CreateLoad(CastToCStr(RHS, B), "rhsc"), + CI->getType(), "rhsv"); + return B.CreateSub(LHSV, RHSV, "chardiff"); + } + + // Constant folding: memcmp(x, y, l) -> cnst (all arguments are constant) + StringRef LHSStr, RHSStr; + if (getConstantStringInfo(LHS, LHSStr) && + getConstantStringInfo(RHS, RHSStr)) { + // Make sure we're not reading out-of-bounds memory. + if (Len > LHSStr.size() || Len > RHSStr.size()) + return nullptr; + // Fold the memcmp and normalize the result. This way we get consistent + // results across multiple platforms. + uint64_t Ret = 0; + int Cmp = memcmp(LHSStr.data(), RHSStr.data(), Len); + if (Cmp < 0) + Ret = -1; + else if (Cmp > 0) + Ret = 1; + return ConstantInt::get(CI->getType(), Ret); } -}; -struct MemCpyOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // These optimizations require DataLayout. - if (!DL) return nullptr; + return nullptr; +} - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 3 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy() || - FT->getParamType(2) != DL->getIntPtrType(*Context)) - return nullptr; +Value *LibCallSimplifier::optimizeMemCpy(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // These optimizations require DataLayout. + if (!DL) + return nullptr; - // memcpy(x, y, n) -> llvm.memcpy(x, y, n, 1) - B.CreateMemCpy(CI->getArgOperand(0), CI->getArgOperand(1), - CI->getArgOperand(2), 1); - return CI->getArgOperand(0); - } -}; + if (!checkStringCopyLibFuncSignature(Callee, LibFunc::memcpy, DL)) + return nullptr; -struct MemMoveOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // These optimizations require DataLayout. - if (!DL) return nullptr; + // memcpy(x, y, n) -> llvm.memcpy(x, y, n, 1) + B.CreateMemCpy(CI->getArgOperand(0), CI->getArgOperand(1), + CI->getArgOperand(2), 1); + return CI->getArgOperand(0); +} - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 3 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy() || - FT->getParamType(2) != DL->getIntPtrType(*Context)) - return nullptr; +Value *LibCallSimplifier::optimizeMemMove(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // These optimizations require DataLayout. + if (!DL) + return nullptr; - // memmove(x, y, n) -> llvm.memmove(x, y, n, 1) - B.CreateMemMove(CI->getArgOperand(0), CI->getArgOperand(1), - CI->getArgOperand(2), 1); - return CI->getArgOperand(0); - } -}; + if (!checkStringCopyLibFuncSignature(Callee, LibFunc::memmove, DL)) + return nullptr; -struct MemSetOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // These optimizations require DataLayout. - if (!DL) return nullptr; + // memmove(x, y, n) -> llvm.memmove(x, y, n, 1) + B.CreateMemMove(CI->getArgOperand(0), CI->getArgOperand(1), + CI->getArgOperand(2), 1); + return CI->getArgOperand(0); +} - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 3 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isIntegerTy() || - FT->getParamType(2) != DL->getIntPtrType(FT->getParamType(0))) - return nullptr; +Value *LibCallSimplifier::optimizeMemSet(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // These optimizations require DataLayout. + if (!DL) + return nullptr; - // memset(p, v, n) -> llvm.memset(p, v, n, 1) - Value *Val = B.CreateIntCast(CI->getArgOperand(1), B.getInt8Ty(), false); - B.CreateMemSet(CI->getArgOperand(0), Val, CI->getArgOperand(2), 1); - return CI->getArgOperand(0); - } -}; + if (!checkStringCopyLibFuncSignature(Callee, LibFunc::memset, DL)) + return nullptr; + + // memset(p, v, n) -> llvm.memset(p, v, n, 1) + Value *Val = B.CreateIntCast(CI->getArgOperand(1), B.getInt8Ty(), false); + B.CreateMemSet(CI->getArgOperand(0), Val, CI->getArgOperand(2), 1); + return CI->getArgOperand(0); +} //===----------------------------------------------------------------------===// // Math Library Optimizations //===----------------------------------------------------------------------===// +/// Return a variant of Val with float type. +/// Currently this works in two cases: If Val is an FPExtension of a float +/// value to something bigger, simply return the operand. +/// If Val is a ConstantFP but can be converted to a float ConstantFP without +/// loss of precision do so. +static Value *valueHasFloatPrecision(Value *Val) { + if (FPExtInst *Cast = dyn_cast<FPExtInst>(Val)) { + Value *Op = Cast->getOperand(0); + if (Op->getType()->isFloatTy()) + return Op; + } + if (ConstantFP *Const = dyn_cast<ConstantFP>(Val)) { + APFloat F = Const->getValueAPF(); + bool losesInfo; + (void)F.convert(APFloat::IEEEsingle, APFloat::rmNearestTiesToEven, + &losesInfo); + if (!losesInfo) + return ConstantFP::get(Const->getContext(), F); + } + return nullptr; +} + //===----------------------------------------------------------------------===// // Double -> Float Shrinking Optimizations for Unary Functions like 'floor' -struct UnaryDoubleFPOpt : public LibCallOptimization { - bool CheckRetType; - UnaryDoubleFPOpt(bool CheckReturnType): CheckRetType(CheckReturnType) {} - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 1 || !FT->getReturnType()->isDoubleTy() || - !FT->getParamType(0)->isDoubleTy()) - return nullptr; +Value *LibCallSimplifier::optimizeUnaryDoubleFP(CallInst *CI, IRBuilder<> &B, + bool CheckRetType) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 1 || !FT->getReturnType()->isDoubleTy() || + !FT->getParamType(0)->isDoubleTy()) + return nullptr; - if (CheckRetType) { - // Check if all the uses for function like 'sin' are converted to float. - for (User *U : CI->users()) { - FPTruncInst *Cast = dyn_cast<FPTruncInst>(U); - if (!Cast || !Cast->getType()->isFloatTy()) - return nullptr; - } + if (CheckRetType) { + // Check if all the uses for function like 'sin' are converted to float. + for (User *U : CI->users()) { + FPTruncInst *Cast = dyn_cast<FPTruncInst>(U); + if (!Cast || !Cast->getType()->isFloatTy()) + return nullptr; } + } - // If this is something like 'floor((double)floatval)', convert to floorf. - FPExtInst *Cast = dyn_cast<FPExtInst>(CI->getArgOperand(0)); - if (!Cast || !Cast->getOperand(0)->getType()->isFloatTy()) - return nullptr; + // If this is something like 'floor((double)floatval)', convert to floorf. + Value *V = valueHasFloatPrecision(CI->getArgOperand(0)); + if (V == nullptr) + return nullptr; - // floor((double)floatval) -> (double)floorf(floatval) - Value *V = Cast->getOperand(0); + // floor((double)floatval) -> (double)floorf(floatval) + if (Callee->isIntrinsic()) { + Module *M = CI->getParent()->getParent()->getParent(); + Intrinsic::ID IID = (Intrinsic::ID) Callee->getIntrinsicID(); + Function *F = Intrinsic::getDeclaration(M, IID, B.getFloatTy()); + V = B.CreateCall(F, V); + } else { + // The call is a library call rather than an intrinsic. V = EmitUnaryFloatFnCall(V, Callee->getName(), B, Callee->getAttributes()); - return B.CreateFPExt(V, B.getDoubleTy()); } -}; + + return B.CreateFPExt(V, B.getDoubleTy()); +} // Double -> Float Shrinking Optimizations for Binary Functions like 'fmin/fmax' -struct BinaryDoubleFPOpt : public LibCallOptimization { - bool CheckRetType; - BinaryDoubleFPOpt(bool CheckReturnType): CheckRetType(CheckReturnType) {} - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - // Just make sure this has 2 arguments of the same FP type, which match the - // result type. - if (FT->getNumParams() != 2 || FT->getReturnType() != FT->getParamType(0) || - FT->getParamType(0) != FT->getParamType(1) || - !FT->getParamType(0)->isFloatingPointTy()) - return nullptr; +Value *LibCallSimplifier::optimizeBinaryDoubleFP(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + // Just make sure this has 2 arguments of the same FP type, which match the + // result type. + if (FT->getNumParams() != 2 || FT->getReturnType() != FT->getParamType(0) || + FT->getParamType(0) != FT->getParamType(1) || + !FT->getParamType(0)->isFloatingPointTy()) + return nullptr; - if (CheckRetType) { - // Check if all the uses for function like 'fmin/fmax' are converted to - // float. - for (User *U : CI->users()) { - FPTruncInst *Cast = dyn_cast<FPTruncInst>(U); - if (!Cast || !Cast->getType()->isFloatTy()) - return nullptr; - } - } + // If this is something like 'fmin((double)floatval1, (double)floatval2)', + // or fmin(1.0, (double)floatval), then we convert it to fminf. + Value *V1 = valueHasFloatPrecision(CI->getArgOperand(0)); + if (V1 == nullptr) + return nullptr; + Value *V2 = valueHasFloatPrecision(CI->getArgOperand(1)); + if (V2 == nullptr) + return nullptr; - // If this is something like 'fmin((double)floatval1, (double)floatval2)', - // we convert it to fminf. - FPExtInst *Cast1 = dyn_cast<FPExtInst>(CI->getArgOperand(0)); - FPExtInst *Cast2 = dyn_cast<FPExtInst>(CI->getArgOperand(1)); - if (!Cast1 || !Cast1->getOperand(0)->getType()->isFloatTy() || - !Cast2 || !Cast2->getOperand(0)->getType()->isFloatTy()) - return nullptr; + // fmin((double)floatval1, (double)floatval2) + // -> (double)fminf(floatval1, floatval2) + // TODO: Handle intrinsics in the same way as in optimizeUnaryDoubleFP(). + Value *V = EmitBinaryFloatFnCall(V1, V2, Callee->getName(), B, + Callee->getAttributes()); + return B.CreateFPExt(V, B.getDoubleTy()); +} - // fmin((double)floatval1, (double)floatval2) - // -> (double)fmin(floatval1, floatval2) - Value *V = nullptr; - Value *V1 = Cast1->getOperand(0); - Value *V2 = Cast2->getOperand(0); - V = EmitBinaryFloatFnCall(V1, V2, Callee->getName(), B, - Callee->getAttributes()); - return B.CreateFPExt(V, B.getDoubleTy()); - } -}; - -struct UnsafeFPLibCallOptimization : public LibCallOptimization { - bool UnsafeFPShrink; - UnsafeFPLibCallOptimization(bool UnsafeFPShrink) { - this->UnsafeFPShrink = UnsafeFPShrink; - } -}; - -struct CosOpt : public UnsafeFPLibCallOptimization { - CosOpt(bool UnsafeFPShrink) : UnsafeFPLibCallOptimization(UnsafeFPShrink) {} - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - Value *Ret = nullptr; - if (UnsafeFPShrink && Callee->getName() == "cos" && - TLI->has(LibFunc::cosf)) { - UnaryDoubleFPOpt UnsafeUnaryDoubleFP(true); - Ret = UnsafeUnaryDoubleFP.callOptimizer(Callee, CI, B); - } +Value *LibCallSimplifier::optimizeCos(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + Value *Ret = nullptr; + if (UnsafeFPShrink && Callee->getName() == "cos" && TLI->has(LibFunc::cosf)) { + Ret = optimizeUnaryDoubleFP(CI, B, true); + } - FunctionType *FT = Callee->getFunctionType(); - // Just make sure this has 1 argument of FP type, which matches the - // result type. - if (FT->getNumParams() != 1 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isFloatingPointTy()) - return Ret; + FunctionType *FT = Callee->getFunctionType(); + // Just make sure this has 1 argument of FP type, which matches the + // result type. + if (FT->getNumParams() != 1 || FT->getReturnType() != FT->getParamType(0) || + !FT->getParamType(0)->isFloatingPointTy()) + return Ret; - // cos(-x) -> cos(x) - Value *Op1 = CI->getArgOperand(0); - if (BinaryOperator::isFNeg(Op1)) { - BinaryOperator *BinExpr = cast<BinaryOperator>(Op1); - return B.CreateCall(Callee, BinExpr->getOperand(1), "cos"); - } + // cos(-x) -> cos(x) + Value *Op1 = CI->getArgOperand(0); + if (BinaryOperator::isFNeg(Op1)) { + BinaryOperator *BinExpr = cast<BinaryOperator>(Op1); + return B.CreateCall(Callee, BinExpr->getOperand(1), "cos"); + } + return Ret; +} + +Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + + Value *Ret = nullptr; + if (UnsafeFPShrink && Callee->getName() == "pow" && TLI->has(LibFunc::powf)) { + Ret = optimizeUnaryDoubleFP(CI, B, true); + } + + FunctionType *FT = Callee->getFunctionType(); + // Just make sure this has 2 arguments of the same FP type, which match the + // result type. + if (FT->getNumParams() != 2 || FT->getReturnType() != FT->getParamType(0) || + FT->getParamType(0) != FT->getParamType(1) || + !FT->getParamType(0)->isFloatingPointTy()) return Ret; + + Value *Op1 = CI->getArgOperand(0), *Op2 = CI->getArgOperand(1); + if (ConstantFP *Op1C = dyn_cast<ConstantFP>(Op1)) { + // pow(1.0, x) -> 1.0 + if (Op1C->isExactlyValue(1.0)) + return Op1C; + // pow(2.0, x) -> exp2(x) + if (Op1C->isExactlyValue(2.0) && + hasUnaryFloatFn(TLI, Op1->getType(), LibFunc::exp2, LibFunc::exp2f, + LibFunc::exp2l)) + return EmitUnaryFloatFnCall(Op2, "exp2", B, Callee->getAttributes()); + // pow(10.0, x) -> exp10(x) + if (Op1C->isExactlyValue(10.0) && + hasUnaryFloatFn(TLI, Op1->getType(), LibFunc::exp10, LibFunc::exp10f, + LibFunc::exp10l)) + return EmitUnaryFloatFnCall(Op2, TLI->getName(LibFunc::exp10), B, + Callee->getAttributes()); + } + + ConstantFP *Op2C = dyn_cast<ConstantFP>(Op2); + if (!Op2C) + return Ret; + + if (Op2C->getValueAPF().isZero()) // pow(x, 0.0) -> 1.0 + return ConstantFP::get(CI->getType(), 1.0); + + if (Op2C->isExactlyValue(0.5) && + hasUnaryFloatFn(TLI, Op2->getType(), LibFunc::sqrt, LibFunc::sqrtf, + LibFunc::sqrtl) && + hasUnaryFloatFn(TLI, Op2->getType(), LibFunc::fabs, LibFunc::fabsf, + LibFunc::fabsl)) { + // Expand pow(x, 0.5) to (x == -infinity ? +infinity : fabs(sqrt(x))). + // This is faster than calling pow, and still handles negative zero + // and negative infinity correctly. + // TODO: In fast-math mode, this could be just sqrt(x). + // TODO: In finite-only mode, this could be just fabs(sqrt(x)). + Value *Inf = ConstantFP::getInfinity(CI->getType()); + Value *NegInf = ConstantFP::getInfinity(CI->getType(), true); + Value *Sqrt = EmitUnaryFloatFnCall(Op1, "sqrt", B, Callee->getAttributes()); + Value *FAbs = + EmitUnaryFloatFnCall(Sqrt, "fabs", B, Callee->getAttributes()); + Value *FCmp = B.CreateFCmpOEQ(Op1, NegInf); + Value *Sel = B.CreateSelect(FCmp, Inf, FAbs); + return Sel; + } + + if (Op2C->isExactlyValue(1.0)) // pow(x, 1.0) -> x + return Op1; + if (Op2C->isExactlyValue(2.0)) // pow(x, 2.0) -> x*x + return B.CreateFMul(Op1, Op1, "pow2"); + if (Op2C->isExactlyValue(-1.0)) // pow(x, -1.0) -> 1.0/x + return B.CreateFDiv(ConstantFP::get(CI->getType(), 1.0), Op1, "powrecip"); + return nullptr; +} + +Value *LibCallSimplifier::optimizeExp2(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + Function *Caller = CI->getParent()->getParent(); + + Value *Ret = nullptr; + if (UnsafeFPShrink && Callee->getName() == "exp2" && + TLI->has(LibFunc::exp2f)) { + Ret = optimizeUnaryDoubleFP(CI, B, true); } -}; - -struct PowOpt : public UnsafeFPLibCallOptimization { - PowOpt(bool UnsafeFPShrink) : UnsafeFPLibCallOptimization(UnsafeFPShrink) {} - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - Value *Ret = nullptr; - if (UnsafeFPShrink && Callee->getName() == "pow" && - TLI->has(LibFunc::powf)) { - UnaryDoubleFPOpt UnsafeUnaryDoubleFP(true); - Ret = UnsafeUnaryDoubleFP.callOptimizer(Callee, CI, B); - } - FunctionType *FT = Callee->getFunctionType(); - // Just make sure this has 2 arguments of the same FP type, which match the - // result type. - if (FT->getNumParams() != 2 || FT->getReturnType() != FT->getParamType(0) || - FT->getParamType(0) != FT->getParamType(1) || - !FT->getParamType(0)->isFloatingPointTy()) - return Ret; + FunctionType *FT = Callee->getFunctionType(); + // Just make sure this has 1 argument of FP type, which matches the + // result type. + if (FT->getNumParams() != 1 || FT->getReturnType() != FT->getParamType(0) || + !FT->getParamType(0)->isFloatingPointTy()) + return Ret; - Value *Op1 = CI->getArgOperand(0), *Op2 = CI->getArgOperand(1); - if (ConstantFP *Op1C = dyn_cast<ConstantFP>(Op1)) { - // pow(1.0, x) -> 1.0 - if (Op1C->isExactlyValue(1.0)) - return Op1C; - // pow(2.0, x) -> exp2(x) - if (Op1C->isExactlyValue(2.0) && - hasUnaryFloatFn(TLI, Op1->getType(), LibFunc::exp2, LibFunc::exp2f, - LibFunc::exp2l)) - return EmitUnaryFloatFnCall(Op2, "exp2", B, Callee->getAttributes()); - // pow(10.0, x) -> exp10(x) - if (Op1C->isExactlyValue(10.0) && - hasUnaryFloatFn(TLI, Op1->getType(), LibFunc::exp10, LibFunc::exp10f, - LibFunc::exp10l)) - return EmitUnaryFloatFnCall(Op2, TLI->getName(LibFunc::exp10), B, - Callee->getAttributes()); + Value *Op = CI->getArgOperand(0); + // Turn exp2(sitofp(x)) -> ldexp(1.0, sext(x)) if sizeof(x) <= 32 + // Turn exp2(uitofp(x)) -> ldexp(1.0, zext(x)) if sizeof(x) < 32 + LibFunc::Func LdExp = LibFunc::ldexpl; + if (Op->getType()->isFloatTy()) + LdExp = LibFunc::ldexpf; + else if (Op->getType()->isDoubleTy()) + LdExp = LibFunc::ldexp; + + if (TLI->has(LdExp)) { + Value *LdExpArg = nullptr; + if (SIToFPInst *OpC = dyn_cast<SIToFPInst>(Op)) { + if (OpC->getOperand(0)->getType()->getPrimitiveSizeInBits() <= 32) + LdExpArg = B.CreateSExt(OpC->getOperand(0), B.getInt32Ty()); + } else if (UIToFPInst *OpC = dyn_cast<UIToFPInst>(Op)) { + if (OpC->getOperand(0)->getType()->getPrimitiveSizeInBits() < 32) + LdExpArg = B.CreateZExt(OpC->getOperand(0), B.getInt32Ty()); } - ConstantFP *Op2C = dyn_cast<ConstantFP>(Op2); - if (!Op2C) return Ret; - - if (Op2C->getValueAPF().isZero()) // pow(x, 0.0) -> 1.0 - return ConstantFP::get(CI->getType(), 1.0); - - if (Op2C->isExactlyValue(0.5) && - hasUnaryFloatFn(TLI, Op2->getType(), LibFunc::sqrt, LibFunc::sqrtf, - LibFunc::sqrtl) && - hasUnaryFloatFn(TLI, Op2->getType(), LibFunc::fabs, LibFunc::fabsf, - LibFunc::fabsl)) { - // Expand pow(x, 0.5) to (x == -infinity ? +infinity : fabs(sqrt(x))). - // This is faster than calling pow, and still handles negative zero - // and negative infinity correctly. - // TODO: In fast-math mode, this could be just sqrt(x). - // TODO: In finite-only mode, this could be just fabs(sqrt(x)). - Value *Inf = ConstantFP::getInfinity(CI->getType()); - Value *NegInf = ConstantFP::getInfinity(CI->getType(), true); - Value *Sqrt = EmitUnaryFloatFnCall(Op1, "sqrt", B, - Callee->getAttributes()); - Value *FAbs = EmitUnaryFloatFnCall(Sqrt, "fabs", B, - Callee->getAttributes()); - Value *FCmp = B.CreateFCmpOEQ(Op1, NegInf); - Value *Sel = B.CreateSelect(FCmp, Inf, FAbs); - return Sel; - } + if (LdExpArg) { + Constant *One = ConstantFP::get(CI->getContext(), APFloat(1.0f)); + if (!Op->getType()->isFloatTy()) + One = ConstantExpr::getFPExtend(One, Op->getType()); + + Module *M = Caller->getParent(); + Value *Callee = + M->getOrInsertFunction(TLI->getName(LdExp), Op->getType(), + Op->getType(), B.getInt32Ty(), nullptr); + CallInst *CI = B.CreateCall2(Callee, One, LdExpArg); + if (const Function *F = dyn_cast<Function>(Callee->stripPointerCasts())) + CI->setCallingConv(F->getCallingConv()); - if (Op2C->isExactlyValue(1.0)) // pow(x, 1.0) -> x - return Op1; - if (Op2C->isExactlyValue(2.0)) // pow(x, 2.0) -> x*x - return B.CreateFMul(Op1, Op1, "pow2"); - if (Op2C->isExactlyValue(-1.0)) // pow(x, -1.0) -> 1.0/x - return B.CreateFDiv(ConstantFP::get(CI->getType(), 1.0), - Op1, "powrecip"); - return nullptr; - } -}; - -struct Exp2Opt : public UnsafeFPLibCallOptimization { - Exp2Opt(bool UnsafeFPShrink) : UnsafeFPLibCallOptimization(UnsafeFPShrink) {} - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - Value *Ret = nullptr; - if (UnsafeFPShrink && Callee->getName() == "exp2" && - TLI->has(LibFunc::exp2f)) { - UnaryDoubleFPOpt UnsafeUnaryDoubleFP(true); - Ret = UnsafeUnaryDoubleFP.callOptimizer(Callee, CI, B); + return CI; } + } + return Ret; +} - FunctionType *FT = Callee->getFunctionType(); - // Just make sure this has 1 argument of FP type, which matches the - // result type. - if (FT->getNumParams() != 1 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isFloatingPointTy()) - return Ret; +Value *LibCallSimplifier::optimizeFabs(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); - Value *Op = CI->getArgOperand(0); - // Turn exp2(sitofp(x)) -> ldexp(1.0, sext(x)) if sizeof(x) <= 32 - // Turn exp2(uitofp(x)) -> ldexp(1.0, zext(x)) if sizeof(x) < 32 - LibFunc::Func LdExp = LibFunc::ldexpl; - if (Op->getType()->isFloatTy()) - LdExp = LibFunc::ldexpf; - else if (Op->getType()->isDoubleTy()) - LdExp = LibFunc::ldexp; - - if (TLI->has(LdExp)) { - Value *LdExpArg = nullptr; - if (SIToFPInst *OpC = dyn_cast<SIToFPInst>(Op)) { - if (OpC->getOperand(0)->getType()->getPrimitiveSizeInBits() <= 32) - LdExpArg = B.CreateSExt(OpC->getOperand(0), B.getInt32Ty()); - } else if (UIToFPInst *OpC = dyn_cast<UIToFPInst>(Op)) { - if (OpC->getOperand(0)->getType()->getPrimitiveSizeInBits() < 32) - LdExpArg = B.CreateZExt(OpC->getOperand(0), B.getInt32Ty()); - } + Value *Ret = nullptr; + if (Callee->getName() == "fabs" && TLI->has(LibFunc::fabsf)) { + Ret = optimizeUnaryDoubleFP(CI, B, false); + } - if (LdExpArg) { - Constant *One = ConstantFP::get(*Context, APFloat(1.0f)); - if (!Op->getType()->isFloatTy()) - One = ConstantExpr::getFPExtend(One, Op->getType()); + FunctionType *FT = Callee->getFunctionType(); + // Make sure this has 1 argument of FP type which matches the result type. + if (FT->getNumParams() != 1 || FT->getReturnType() != FT->getParamType(0) || + !FT->getParamType(0)->isFloatingPointTy()) + return Ret; - Module *M = Caller->getParent(); - Value *Callee = - M->getOrInsertFunction(TLI->getName(LdExp), Op->getType(), - Op->getType(), B.getInt32Ty(), NULL); - CallInst *CI = B.CreateCall2(Callee, One, LdExpArg); - if (const Function *F = dyn_cast<Function>(Callee->stripPointerCasts())) - CI->setCallingConv(F->getCallingConv()); + Value *Op = CI->getArgOperand(0); + if (Instruction *I = dyn_cast<Instruction>(Op)) { + // Fold fabs(x * x) -> x * x; any squared FP value must already be positive. + if (I->getOpcode() == Instruction::FMul) + if (I->getOperand(0) == I->getOperand(1)) + return Op; + } + return Ret; +} - return CI; +Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + + Value *Ret = nullptr; + if (TLI->has(LibFunc::sqrtf) && (Callee->getName() == "sqrt" || + Callee->getIntrinsicID() == Intrinsic::sqrt)) + Ret = optimizeUnaryDoubleFP(CI, B, true); + + // FIXME: For finer-grain optimization, we need intrinsics to have the same + // fast-math flag decorations that are applied to FP instructions. For now, + // we have to rely on the function-level unsafe-fp-math attribute to do this + // optimization because there's no other way to express that the sqrt can be + // reassociated. + Function *F = CI->getParent()->getParent(); + if (F->hasFnAttribute("unsafe-fp-math")) { + // Check for unsafe-fp-math = true. + Attribute Attr = F->getFnAttribute("unsafe-fp-math"); + if (Attr.getValueAsString() != "true") + return Ret; + } + Value *Op = CI->getArgOperand(0); + if (Instruction *I = dyn_cast<Instruction>(Op)) { + if (I->getOpcode() == Instruction::FMul && I->hasUnsafeAlgebra()) { + // We're looking for a repeated factor in a multiplication tree, + // so we can do this fold: sqrt(x * x) -> fabs(x); + // or this fold: sqrt(x * x * y) -> fabs(x) * sqrt(y). + Value *Op0 = I->getOperand(0); + Value *Op1 = I->getOperand(1); + Value *RepeatOp = nullptr; + Value *OtherOp = nullptr; + if (Op0 == Op1) { + // Simple match: the operands of the multiply are identical. + RepeatOp = Op0; + } else { + // Look for a more complicated pattern: one of the operands is itself + // a multiply, so search for a common factor in that multiply. + // Note: We don't bother looking any deeper than this first level or for + // variations of this pattern because instcombine's visitFMUL and/or the + // reassociation pass should give us this form. + Value *OtherMul0, *OtherMul1; + if (match(Op0, m_FMul(m_Value(OtherMul0), m_Value(OtherMul1)))) { + // Pattern: sqrt((x * y) * z) + if (OtherMul0 == OtherMul1) { + // Matched: sqrt((x * x) * z) + RepeatOp = OtherMul0; + OtherOp = Op1; + } + } + } + if (RepeatOp) { + // Fast math flags for any created instructions should match the sqrt + // and multiply. + // FIXME: We're not checking the sqrt because it doesn't have + // fast-math-flags (see earlier comment). + IRBuilder<true, ConstantFolder, + IRBuilderDefaultInserter<true> >::FastMathFlagGuard Guard(B); + B.SetFastMathFlags(I->getFastMathFlags()); + // If we found a repeated factor, hoist it out of the square root and + // replace it with the fabs of that factor. + Module *M = Callee->getParent(); + Type *ArgType = Op->getType(); + Value *Fabs = Intrinsic::getDeclaration(M, Intrinsic::fabs, ArgType); + Value *FabsCall = B.CreateCall(Fabs, RepeatOp, "fabs"); + if (OtherOp) { + // If we found a non-repeated factor, we still need to get its square + // root. We then multiply that by the value that was simplified out + // of the square root calculation. + Value *Sqrt = Intrinsic::getDeclaration(M, Intrinsic::sqrt, ArgType); + Value *SqrtCall = B.CreateCall(Sqrt, OtherOp, "sqrt"); + return B.CreateFMul(FabsCall, SqrtCall); + } + return FabsCall; } } - return Ret; } -}; + return Ret; +} -struct SinCosPiOpt : public LibCallOptimization { - SinCosPiOpt() {} +static bool isTrigLibCall(CallInst *CI); +static void insertSinCosCall(IRBuilder<> &B, Function *OrigCallee, Value *Arg, + bool UseFloat, Value *&Sin, Value *&Cos, + Value *&SinCos); - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // Make sure the prototype is as expected, otherwise the rest of the - // function is probably invalid and likely to abort. - if (!isTrigLibCall(CI)) - return nullptr; +Value *LibCallSimplifier::optimizeSinCosPi(CallInst *CI, IRBuilder<> &B) { - Value *Arg = CI->getArgOperand(0); - SmallVector<CallInst *, 1> SinCalls; - SmallVector<CallInst *, 1> CosCalls; - SmallVector<CallInst *, 1> SinCosCalls; + // Make sure the prototype is as expected, otherwise the rest of the + // function is probably invalid and likely to abort. + if (!isTrigLibCall(CI)) + return nullptr; - bool IsFloat = Arg->getType()->isFloatTy(); + Value *Arg = CI->getArgOperand(0); + SmallVector<CallInst *, 1> SinCalls; + SmallVector<CallInst *, 1> CosCalls; + SmallVector<CallInst *, 1> SinCosCalls; - // Look for all compatible sinpi, cospi and sincospi calls with the same - // argument. If there are enough (in some sense) we can make the - // substitution. - for (User *U : Arg->users()) - classifyArgUse(U, CI->getParent(), IsFloat, SinCalls, CosCalls, - SinCosCalls); + bool IsFloat = Arg->getType()->isFloatTy(); - // It's only worthwhile if both sinpi and cospi are actually used. - if (SinCosCalls.empty() && (SinCalls.empty() || CosCalls.empty())) - return nullptr; + // Look for all compatible sinpi, cospi and sincospi calls with the same + // argument. If there are enough (in some sense) we can make the + // substitution. + for (User *U : Arg->users()) + classifyArgUse(U, CI->getParent(), IsFloat, SinCalls, CosCalls, + SinCosCalls); - Value *Sin, *Cos, *SinCos; - insertSinCosCall(B, CI->getCalledFunction(), Arg, IsFloat, Sin, Cos, - SinCos); - - replaceTrigInsts(SinCalls, Sin); - replaceTrigInsts(CosCalls, Cos); - replaceTrigInsts(SinCosCalls, SinCos); - - return nullptr; - } - - bool isTrigLibCall(CallInst *CI) { - Function *Callee = CI->getCalledFunction(); - FunctionType *FT = Callee->getFunctionType(); - - // We can only hope to do anything useful if we can ignore things like errno - // and floating-point exceptions. - bool AttributesSafe = CI->hasFnAttr(Attribute::NoUnwind) && - CI->hasFnAttr(Attribute::ReadNone); - - // Other than that we need float(float) or double(double) - return AttributesSafe && FT->getNumParams() == 1 && - FT->getReturnType() == FT->getParamType(0) && - (FT->getParamType(0)->isFloatTy() || - FT->getParamType(0)->isDoubleTy()); - } - - void classifyArgUse(Value *Val, BasicBlock *BB, bool IsFloat, - SmallVectorImpl<CallInst *> &SinCalls, - SmallVectorImpl<CallInst *> &CosCalls, - SmallVectorImpl<CallInst *> &SinCosCalls) { - CallInst *CI = dyn_cast<CallInst>(Val); - - if (!CI) - return; - - Function *Callee = CI->getCalledFunction(); - StringRef FuncName = Callee->getName(); - LibFunc::Func Func; - if (!TLI->getLibFunc(FuncName, Func) || !TLI->has(Func) || - !isTrigLibCall(CI)) - return; - - if (IsFloat) { - if (Func == LibFunc::sinpif) - SinCalls.push_back(CI); - else if (Func == LibFunc::cospif) - CosCalls.push_back(CI); - else if (Func == LibFunc::sincospif_stret) - SinCosCalls.push_back(CI); - } else { - if (Func == LibFunc::sinpi) - SinCalls.push_back(CI); - else if (Func == LibFunc::cospi) - CosCalls.push_back(CI); - else if (Func == LibFunc::sincospi_stret) - SinCosCalls.push_back(CI); - } - } + // It's only worthwhile if both sinpi and cospi are actually used. + if (SinCosCalls.empty() && (SinCalls.empty() || CosCalls.empty())) + return nullptr; - void replaceTrigInsts(SmallVectorImpl<CallInst*> &Calls, Value *Res) { - for (SmallVectorImpl<CallInst*>::iterator I = Calls.begin(), - E = Calls.end(); - I != E; ++I) { - LCS->replaceAllUsesWith(*I, Res); - } - } + Value *Sin, *Cos, *SinCos; + insertSinCosCall(B, CI->getCalledFunction(), Arg, IsFloat, Sin, Cos, SinCos); - void insertSinCosCall(IRBuilder<> &B, Function *OrigCallee, Value *Arg, - bool UseFloat, Value *&Sin, Value *&Cos, - Value *&SinCos) { - Type *ArgTy = Arg->getType(); - Type *ResTy; - StringRef Name; - - Triple T(OrigCallee->getParent()->getTargetTriple()); - if (UseFloat) { - Name = "__sincospif_stret"; - - assert(T.getArch() != Triple::x86 && "x86 messy and unsupported for now"); - // x86_64 can't use {float, float} since that would be returned in both - // xmm0 and xmm1, which isn't what a real struct would do. - ResTy = T.getArch() == Triple::x86_64 - ? static_cast<Type *>(VectorType::get(ArgTy, 2)) - : static_cast<Type *>(StructType::get(ArgTy, ArgTy, NULL)); - } else { - Name = "__sincospi_stret"; - ResTy = StructType::get(ArgTy, ArgTy, NULL); - } + replaceTrigInsts(SinCalls, Sin); + replaceTrigInsts(CosCalls, Cos); + replaceTrigInsts(SinCosCalls, SinCos); - Module *M = OrigCallee->getParent(); - Value *Callee = M->getOrInsertFunction(Name, OrigCallee->getAttributes(), - ResTy, ArgTy, NULL); - - if (Instruction *ArgInst = dyn_cast<Instruction>(Arg)) { - // If the argument is an instruction, it must dominate all uses so put our - // sincos call there. - BasicBlock::iterator Loc = ArgInst; - B.SetInsertPoint(ArgInst->getParent(), ++Loc); - } else { - // Otherwise (e.g. for a constant) the beginning of the function is as - // good a place as any. - BasicBlock &EntryBB = B.GetInsertBlock()->getParent()->getEntryBlock(); - B.SetInsertPoint(&EntryBB, EntryBB.begin()); - } + return nullptr; +} - SinCos = B.CreateCall(Callee, Arg, "sincospi"); +static bool isTrigLibCall(CallInst *CI) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + + // We can only hope to do anything useful if we can ignore things like errno + // and floating-point exceptions. + bool AttributesSafe = + CI->hasFnAttr(Attribute::NoUnwind) && CI->hasFnAttr(Attribute::ReadNone); + + // Other than that we need float(float) or double(double) + return AttributesSafe && FT->getNumParams() == 1 && + FT->getReturnType() == FT->getParamType(0) && + (FT->getParamType(0)->isFloatTy() || + FT->getParamType(0)->isDoubleTy()); +} - if (SinCos->getType()->isStructTy()) { - Sin = B.CreateExtractValue(SinCos, 0, "sinpi"); - Cos = B.CreateExtractValue(SinCos, 1, "cospi"); - } else { - Sin = B.CreateExtractElement(SinCos, ConstantInt::get(B.getInt32Ty(), 0), - "sinpi"); - Cos = B.CreateExtractElement(SinCos, ConstantInt::get(B.getInt32Ty(), 1), - "cospi"); - } +void +LibCallSimplifier::classifyArgUse(Value *Val, BasicBlock *BB, bool IsFloat, + SmallVectorImpl<CallInst *> &SinCalls, + SmallVectorImpl<CallInst *> &CosCalls, + SmallVectorImpl<CallInst *> &SinCosCalls) { + CallInst *CI = dyn_cast<CallInst>(Val); + + if (!CI) + return; + + Function *Callee = CI->getCalledFunction(); + StringRef FuncName = Callee->getName(); + LibFunc::Func Func; + if (!TLI->getLibFunc(FuncName, Func) || !TLI->has(Func) || !isTrigLibCall(CI)) + return; + + if (IsFloat) { + if (Func == LibFunc::sinpif) + SinCalls.push_back(CI); + else if (Func == LibFunc::cospif) + CosCalls.push_back(CI); + else if (Func == LibFunc::sincospif_stret) + SinCosCalls.push_back(CI); + } else { + if (Func == LibFunc::sinpi) + SinCalls.push_back(CI); + else if (Func == LibFunc::cospi) + CosCalls.push_back(CI); + else if (Func == LibFunc::sincospi_stret) + SinCosCalls.push_back(CI); + } +} + +void LibCallSimplifier::replaceTrigInsts(SmallVectorImpl<CallInst *> &Calls, + Value *Res) { + for (SmallVectorImpl<CallInst *>::iterator I = Calls.begin(), E = Calls.end(); + I != E; ++I) { + replaceAllUsesWith(*I, Res); } +} -}; +void insertSinCosCall(IRBuilder<> &B, Function *OrigCallee, Value *Arg, + bool UseFloat, Value *&Sin, Value *&Cos, Value *&SinCos) { + Type *ArgTy = Arg->getType(); + Type *ResTy; + StringRef Name; + + Triple T(OrigCallee->getParent()->getTargetTriple()); + if (UseFloat) { + Name = "__sincospif_stret"; + + assert(T.getArch() != Triple::x86 && "x86 messy and unsupported for now"); + // x86_64 can't use {float, float} since that would be returned in both + // xmm0 and xmm1, which isn't what a real struct would do. + ResTy = T.getArch() == Triple::x86_64 + ? static_cast<Type *>(VectorType::get(ArgTy, 2)) + : static_cast<Type *>(StructType::get(ArgTy, ArgTy, nullptr)); + } else { + Name = "__sincospi_stret"; + ResTy = StructType::get(ArgTy, ArgTy, nullptr); + } + + Module *M = OrigCallee->getParent(); + Value *Callee = M->getOrInsertFunction(Name, OrigCallee->getAttributes(), + ResTy, ArgTy, nullptr); + + if (Instruction *ArgInst = dyn_cast<Instruction>(Arg)) { + // If the argument is an instruction, it must dominate all uses so put our + // sincos call there. + BasicBlock::iterator Loc = ArgInst; + B.SetInsertPoint(ArgInst->getParent(), ++Loc); + } else { + // Otherwise (e.g. for a constant) the beginning of the function is as + // good a place as any. + BasicBlock &EntryBB = B.GetInsertBlock()->getParent()->getEntryBlock(); + B.SetInsertPoint(&EntryBB, EntryBB.begin()); + } + + SinCos = B.CreateCall(Callee, Arg, "sincospi"); + + if (SinCos->getType()->isStructTy()) { + Sin = B.CreateExtractValue(SinCos, 0, "sinpi"); + Cos = B.CreateExtractValue(SinCos, 1, "cospi"); + } else { + Sin = B.CreateExtractElement(SinCos, ConstantInt::get(B.getInt32Ty(), 0), + "sinpi"); + Cos = B.CreateExtractElement(SinCos, ConstantInt::get(B.getInt32Ty(), 1), + "cospi"); + } +} //===----------------------------------------------------------------------===// // Integer Library Call Optimizations //===----------------------------------------------------------------------===// -struct FFSOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - // Just make sure this has 2 arguments of the same FP type, which match the - // result type. - if (FT->getNumParams() != 1 || - !FT->getReturnType()->isIntegerTy(32) || - !FT->getParamType(0)->isIntegerTy()) - return nullptr; +Value *LibCallSimplifier::optimizeFFS(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + // Just make sure this has 2 arguments of the same FP type, which match the + // result type. + if (FT->getNumParams() != 1 || !FT->getReturnType()->isIntegerTy(32) || + !FT->getParamType(0)->isIntegerTy()) + return nullptr; - Value *Op = CI->getArgOperand(0); + Value *Op = CI->getArgOperand(0); - // Constant fold. - if (ConstantInt *CI = dyn_cast<ConstantInt>(Op)) { - if (CI->isZero()) // ffs(0) -> 0. - return B.getInt32(0); - // ffs(c) -> cttz(c)+1 - return B.getInt32(CI->getValue().countTrailingZeros() + 1); - } + // Constant fold. + if (ConstantInt *CI = dyn_cast<ConstantInt>(Op)) { + if (CI->isZero()) // ffs(0) -> 0. + return B.getInt32(0); + // ffs(c) -> cttz(c)+1 + return B.getInt32(CI->getValue().countTrailingZeros() + 1); + } - // ffs(x) -> x != 0 ? (i32)llvm.cttz(x)+1 : 0 - Type *ArgType = Op->getType(); - Value *F = Intrinsic::getDeclaration(Callee->getParent(), - Intrinsic::cttz, ArgType); - Value *V = B.CreateCall2(F, Op, B.getFalse(), "cttz"); - V = B.CreateAdd(V, ConstantInt::get(V->getType(), 1)); - V = B.CreateIntCast(V, B.getInt32Ty(), false); - - Value *Cond = B.CreateICmpNE(Op, Constant::getNullValue(ArgType)); - return B.CreateSelect(Cond, V, B.getInt32(0)); - } -}; - -struct AbsOpt : public LibCallOptimization { - bool ignoreCallingConv() override { return true; } - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - // We require integer(integer) where the types agree. - if (FT->getNumParams() != 1 || !FT->getReturnType()->isIntegerTy() || - FT->getParamType(0) != FT->getReturnType()) - return nullptr; + // ffs(x) -> x != 0 ? (i32)llvm.cttz(x)+1 : 0 + Type *ArgType = Op->getType(); + Value *F = + Intrinsic::getDeclaration(Callee->getParent(), Intrinsic::cttz, ArgType); + Value *V = B.CreateCall2(F, Op, B.getFalse(), "cttz"); + V = B.CreateAdd(V, ConstantInt::get(V->getType(), 1)); + V = B.CreateIntCast(V, B.getInt32Ty(), false); - // abs(x) -> x >s -1 ? x : -x - Value *Op = CI->getArgOperand(0); - Value *Pos = B.CreateICmpSGT(Op, Constant::getAllOnesValue(Op->getType()), - "ispos"); - Value *Neg = B.CreateNeg(Op, "neg"); - return B.CreateSelect(Pos, Op, Neg); - } -}; - -struct IsDigitOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - // We require integer(i32) - if (FT->getNumParams() != 1 || !FT->getReturnType()->isIntegerTy() || - !FT->getParamType(0)->isIntegerTy(32)) - return nullptr; + Value *Cond = B.CreateICmpNE(Op, Constant::getNullValue(ArgType)); + return B.CreateSelect(Cond, V, B.getInt32(0)); +} - // isdigit(c) -> (c-'0') <u 10 - Value *Op = CI->getArgOperand(0); - Op = B.CreateSub(Op, B.getInt32('0'), "isdigittmp"); - Op = B.CreateICmpULT(Op, B.getInt32(10), "isdigit"); - return B.CreateZExt(Op, CI->getType()); - } -}; - -struct IsAsciiOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - // We require integer(i32) - if (FT->getNumParams() != 1 || !FT->getReturnType()->isIntegerTy() || - !FT->getParamType(0)->isIntegerTy(32)) - return nullptr; +Value *LibCallSimplifier::optimizeAbs(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + // We require integer(integer) where the types agree. + if (FT->getNumParams() != 1 || !FT->getReturnType()->isIntegerTy() || + FT->getParamType(0) != FT->getReturnType()) + return nullptr; - // isascii(c) -> c <u 128 - Value *Op = CI->getArgOperand(0); - Op = B.CreateICmpULT(Op, B.getInt32(128), "isascii"); - return B.CreateZExt(Op, CI->getType()); - } -}; + // abs(x) -> x >s -1 ? x : -x + Value *Op = CI->getArgOperand(0); + Value *Pos = + B.CreateICmpSGT(Op, Constant::getAllOnesValue(Op->getType()), "ispos"); + Value *Neg = B.CreateNeg(Op, "neg"); + return B.CreateSelect(Pos, Op, Neg); +} -struct ToAsciiOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - // We require i32(i32) - if (FT->getNumParams() != 1 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isIntegerTy(32)) - return nullptr; +Value *LibCallSimplifier::optimizeIsDigit(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + // We require integer(i32) + if (FT->getNumParams() != 1 || !FT->getReturnType()->isIntegerTy() || + !FT->getParamType(0)->isIntegerTy(32)) + return nullptr; - // toascii(c) -> c & 0x7f - return B.CreateAnd(CI->getArgOperand(0), - ConstantInt::get(CI->getType(),0x7F)); - } -}; + // isdigit(c) -> (c-'0') <u 10 + Value *Op = CI->getArgOperand(0); + Op = B.CreateSub(Op, B.getInt32('0'), "isdigittmp"); + Op = B.CreateICmpULT(Op, B.getInt32(10), "isdigit"); + return B.CreateZExt(Op, CI->getType()); +} + +Value *LibCallSimplifier::optimizeIsAscii(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + // We require integer(i32) + if (FT->getNumParams() != 1 || !FT->getReturnType()->isIntegerTy() || + !FT->getParamType(0)->isIntegerTy(32)) + return nullptr; + + // isascii(c) -> c <u 128 + Value *Op = CI->getArgOperand(0); + Op = B.CreateICmpULT(Op, B.getInt32(128), "isascii"); + return B.CreateZExt(Op, CI->getType()); +} + +Value *LibCallSimplifier::optimizeToAscii(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + // We require i32(i32) + if (FT->getNumParams() != 1 || FT->getReturnType() != FT->getParamType(0) || + !FT->getParamType(0)->isIntegerTy(32)) + return nullptr; + + // toascii(c) -> c & 0x7f + return B.CreateAnd(CI->getArgOperand(0), + ConstantInt::get(CI->getType(), 0x7F)); +} //===----------------------------------------------------------------------===// // Formatting and IO Library Call Optimizations //===----------------------------------------------------------------------===// -struct ErrorReportingOpt : public LibCallOptimization { - ErrorReportingOpt(int S = -1) : StreamArg(S) {} +static bool isReportingError(Function *Callee, CallInst *CI, int StreamArg); - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &) override { - // Error reporting calls should be cold, mark them as such. - // This applies even to non-builtin calls: it is only a hint and applies to - // functions that the frontend might not understand as builtins. +Value *LibCallSimplifier::optimizeErrorReporting(CallInst *CI, IRBuilder<> &B, + int StreamArg) { + // Error reporting calls should be cold, mark them as such. + // This applies even to non-builtin calls: it is only a hint and applies to + // functions that the frontend might not understand as builtins. - // This heuristic was suggested in: - // Improving Static Branch Prediction in a Compiler - // Brian L. Deitrich, Ben-Chung Cheng, Wen-mei W. Hwu - // Proceedings of PACT'98, Oct. 1998, IEEE - - if (!CI->hasFnAttr(Attribute::Cold) && isReportingError(Callee, CI)) { - CI->addAttribute(AttributeSet::FunctionIndex, Attribute::Cold); - } + // This heuristic was suggested in: + // Improving Static Branch Prediction in a Compiler + // Brian L. Deitrich, Ben-Chung Cheng, Wen-mei W. Hwu + // Proceedings of PACT'98, Oct. 1998, IEEE + Function *Callee = CI->getCalledFunction(); - return nullptr; + if (!CI->hasFnAttr(Attribute::Cold) && + isReportingError(Callee, CI, StreamArg)) { + CI->addAttribute(AttributeSet::FunctionIndex, Attribute::Cold); } -protected: - bool isReportingError(Function *Callee, CallInst *CI) { - if (!ColdErrorCalls) - return false; - - if (!Callee || !Callee->isDeclaration()) - return false; + return nullptr; +} - if (StreamArg < 0) - return true; +static bool isReportingError(Function *Callee, CallInst *CI, int StreamArg) { + if (!ColdErrorCalls) + return false; - // These functions might be considered cold, but only if their stream - // argument is stderr. + if (!Callee || !Callee->isDeclaration()) + return false; - if (StreamArg >= (int) CI->getNumArgOperands()) - return false; - LoadInst *LI = dyn_cast<LoadInst>(CI->getArgOperand(StreamArg)); - if (!LI) - return false; - GlobalVariable *GV = dyn_cast<GlobalVariable>(LI->getPointerOperand()); - if (!GV || !GV->isDeclaration()) - return false; - return GV->getName() == "stderr"; - } + if (StreamArg < 0) + return true; - int StreamArg; -}; + // These functions might be considered cold, but only if their stream + // argument is stderr. -struct PrintFOpt : public LibCallOptimization { - Value *optimizeFixedFormatString(Function *Callee, CallInst *CI, - IRBuilder<> &B) { - // Check for a fixed format string. - StringRef FormatStr; - if (!getConstantStringInfo(CI->getArgOperand(0), FormatStr)) - return nullptr; + if (StreamArg >= (int)CI->getNumArgOperands()) + return false; + LoadInst *LI = dyn_cast<LoadInst>(CI->getArgOperand(StreamArg)); + if (!LI) + return false; + GlobalVariable *GV = dyn_cast<GlobalVariable>(LI->getPointerOperand()); + if (!GV || !GV->isDeclaration()) + return false; + return GV->getName() == "stderr"; +} - // Empty format string -> noop. - if (FormatStr.empty()) // Tolerate printf's declared void. - return CI->use_empty() ? (Value*)CI : - ConstantInt::get(CI->getType(), 0); +Value *LibCallSimplifier::optimizePrintFString(CallInst *CI, IRBuilder<> &B) { + // Check for a fixed format string. + StringRef FormatStr; + if (!getConstantStringInfo(CI->getArgOperand(0), FormatStr)) + return nullptr; - // Do not do any of the following transformations if the printf return value - // is used, in general the printf return value is not compatible with either - // putchar() or puts(). - if (!CI->use_empty()) - return nullptr; + // Empty format string -> noop. + if (FormatStr.empty()) // Tolerate printf's declared void. + return CI->use_empty() ? (Value *)CI : ConstantInt::get(CI->getType(), 0); - // printf("x") -> putchar('x'), even for '%'. - if (FormatStr.size() == 1) { - Value *Res = EmitPutChar(B.getInt32(FormatStr[0]), B, DL, TLI); - if (CI->use_empty() || !Res) return Res; - return B.CreateIntCast(Res, CI->getType(), true); - } + // Do not do any of the following transformations if the printf return value + // is used, in general the printf return value is not compatible with either + // putchar() or puts(). + if (!CI->use_empty()) + return nullptr; - // printf("foo\n") --> puts("foo") - if (FormatStr[FormatStr.size()-1] == '\n' && - FormatStr.find('%') == StringRef::npos) { // No format characters. - // Create a string literal with no \n on it. We expect the constant merge - // pass to be run after this pass, to merge duplicate strings. - FormatStr = FormatStr.drop_back(); - Value *GV = B.CreateGlobalString(FormatStr, "str"); - Value *NewCI = EmitPutS(GV, B, DL, TLI); - return (CI->use_empty() || !NewCI) ? - NewCI : - ConstantInt::get(CI->getType(), FormatStr.size()+1); - } + // printf("x") -> putchar('x'), even for '%'. + if (FormatStr.size() == 1) { + Value *Res = EmitPutChar(B.getInt32(FormatStr[0]), B, DL, TLI); + if (CI->use_empty() || !Res) + return Res; + return B.CreateIntCast(Res, CI->getType(), true); + } - // Optimize specific format strings. - // printf("%c", chr) --> putchar(chr) - if (FormatStr == "%c" && CI->getNumArgOperands() > 1 && - CI->getArgOperand(1)->getType()->isIntegerTy()) { - Value *Res = EmitPutChar(CI->getArgOperand(1), B, DL, TLI); + // printf("foo\n") --> puts("foo") + if (FormatStr[FormatStr.size() - 1] == '\n' && + FormatStr.find('%') == StringRef::npos) { // No format characters. + // Create a string literal with no \n on it. We expect the constant merge + // pass to be run after this pass, to merge duplicate strings. + FormatStr = FormatStr.drop_back(); + Value *GV = B.CreateGlobalString(FormatStr, "str"); + Value *NewCI = EmitPutS(GV, B, DL, TLI); + return (CI->use_empty() || !NewCI) + ? NewCI + : ConstantInt::get(CI->getType(), FormatStr.size() + 1); + } - if (CI->use_empty() || !Res) return Res; - return B.CreateIntCast(Res, CI->getType(), true); - } + // Optimize specific format strings. + // printf("%c", chr) --> putchar(chr) + if (FormatStr == "%c" && CI->getNumArgOperands() > 1 && + CI->getArgOperand(1)->getType()->isIntegerTy()) { + Value *Res = EmitPutChar(CI->getArgOperand(1), B, DL, TLI); - // printf("%s\n", str) --> puts(str) - if (FormatStr == "%s\n" && CI->getNumArgOperands() > 1 && - CI->getArgOperand(1)->getType()->isPointerTy()) { - return EmitPutS(CI->getArgOperand(1), B, DL, TLI); - } - return nullptr; + if (CI->use_empty() || !Res) + return Res; + return B.CreateIntCast(Res, CI->getType(), true); } - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // Require one fixed pointer argument and an integer/void result. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() < 1 || !FT->getParamType(0)->isPointerTy() || - !(FT->getReturnType()->isIntegerTy() || - FT->getReturnType()->isVoidTy())) - return nullptr; + // printf("%s\n", str) --> puts(str) + if (FormatStr == "%s\n" && CI->getNumArgOperands() > 1 && + CI->getArgOperand(1)->getType()->isPointerTy()) { + return EmitPutS(CI->getArgOperand(1), B, DL, TLI); + } + return nullptr; +} - if (Value *V = optimizeFixedFormatString(Callee, CI, B)) { - return V; - } +Value *LibCallSimplifier::optimizePrintF(CallInst *CI, IRBuilder<> &B) { - // printf(format, ...) -> iprintf(format, ...) if no floating point - // arguments. - if (TLI->has(LibFunc::iprintf) && !callHasFloatingPointArgument(CI)) { - Module *M = B.GetInsertBlock()->getParent()->getParent(); - Constant *IPrintFFn = - M->getOrInsertFunction("iprintf", FT, Callee->getAttributes()); - CallInst *New = cast<CallInst>(CI->clone()); - New->setCalledFunction(IPrintFFn); - B.Insert(New); - return New; - } + Function *Callee = CI->getCalledFunction(); + // Require one fixed pointer argument and an integer/void result. + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() < 1 || !FT->getParamType(0)->isPointerTy() || + !(FT->getReturnType()->isIntegerTy() || FT->getReturnType()->isVoidTy())) return nullptr; + + if (Value *V = optimizePrintFString(CI, B)) { + return V; } -}; -struct SPrintFOpt : public LibCallOptimization { - Value *OptimizeFixedFormatString(Function *Callee, CallInst *CI, - IRBuilder<> &B) { - // Check for a fixed format string. - StringRef FormatStr; - if (!getConstantStringInfo(CI->getArgOperand(1), FormatStr)) - return nullptr; + // printf(format, ...) -> iprintf(format, ...) if no floating point + // arguments. + if (TLI->has(LibFunc::iprintf) && !callHasFloatingPointArgument(CI)) { + Module *M = B.GetInsertBlock()->getParent()->getParent(); + Constant *IPrintFFn = + M->getOrInsertFunction("iprintf", FT, Callee->getAttributes()); + CallInst *New = cast<CallInst>(CI->clone()); + New->setCalledFunction(IPrintFFn); + B.Insert(New); + return New; + } + return nullptr; +} - // If we just have a format string (nothing else crazy) transform it. - if (CI->getNumArgOperands() == 2) { - // Make sure there's no % in the constant array. We could try to handle - // %% -> % in the future if we cared. - for (unsigned i = 0, e = FormatStr.size(); i != e; ++i) - if (FormatStr[i] == '%') - return nullptr; // we found a format specifier, bail out. - - // These optimizations require DataLayout. - if (!DL) return nullptr; - - // sprintf(str, fmt) -> llvm.memcpy(str, fmt, strlen(fmt)+1, 1) - B.CreateMemCpy(CI->getArgOperand(0), CI->getArgOperand(1), - ConstantInt::get(DL->getIntPtrType(*Context), // Copy the - FormatStr.size() + 1), 1); // nul byte. - return ConstantInt::get(CI->getType(), FormatStr.size()); - } +Value *LibCallSimplifier::optimizeSPrintFString(CallInst *CI, IRBuilder<> &B) { + // Check for a fixed format string. + StringRef FormatStr; + if (!getConstantStringInfo(CI->getArgOperand(1), FormatStr)) + return nullptr; - // The remaining optimizations require the format string to be "%s" or "%c" - // and have an extra operand. - if (FormatStr.size() != 2 || FormatStr[0] != '%' || - CI->getNumArgOperands() < 3) - return nullptr; + // If we just have a format string (nothing else crazy) transform it. + if (CI->getNumArgOperands() == 2) { + // Make sure there's no % in the constant array. We could try to handle + // %% -> % in the future if we cared. + for (unsigned i = 0, e = FormatStr.size(); i != e; ++i) + if (FormatStr[i] == '%') + return nullptr; // we found a format specifier, bail out. - // Decode the second character of the format string. - if (FormatStr[1] == 'c') { - // sprintf(dst, "%c", chr) --> *(i8*)dst = chr; *((i8*)dst+1) = 0 - if (!CI->getArgOperand(2)->getType()->isIntegerTy()) return nullptr; - Value *V = B.CreateTrunc(CI->getArgOperand(2), B.getInt8Ty(), "char"); - Value *Ptr = CastToCStr(CI->getArgOperand(0), B); - B.CreateStore(V, Ptr); - Ptr = B.CreateGEP(Ptr, B.getInt32(1), "nul"); - B.CreateStore(B.getInt8(0), Ptr); - - return ConstantInt::get(CI->getType(), 1); - } + // These optimizations require DataLayout. + if (!DL) + return nullptr; - if (FormatStr[1] == 's') { - // These optimizations require DataLayout. - if (!DL) return nullptr; + // sprintf(str, fmt) -> llvm.memcpy(str, fmt, strlen(fmt)+1, 1) + B.CreateMemCpy( + CI->getArgOperand(0), CI->getArgOperand(1), + ConstantInt::get(DL->getIntPtrType(CI->getContext()), + FormatStr.size() + 1), + 1); // Copy the null byte. + return ConstantInt::get(CI->getType(), FormatStr.size()); + } - // sprintf(dest, "%s", str) -> llvm.memcpy(dest, str, strlen(str)+1, 1) - if (!CI->getArgOperand(2)->getType()->isPointerTy()) return nullptr; + // The remaining optimizations require the format string to be "%s" or "%c" + // and have an extra operand. + if (FormatStr.size() != 2 || FormatStr[0] != '%' || + CI->getNumArgOperands() < 3) + return nullptr; - Value *Len = EmitStrLen(CI->getArgOperand(2), B, DL, TLI); - if (!Len) - return nullptr; - Value *IncLen = B.CreateAdd(Len, - ConstantInt::get(Len->getType(), 1), - "leninc"); - B.CreateMemCpy(CI->getArgOperand(0), CI->getArgOperand(2), IncLen, 1); + // Decode the second character of the format string. + if (FormatStr[1] == 'c') { + // sprintf(dst, "%c", chr) --> *(i8*)dst = chr; *((i8*)dst+1) = 0 + if (!CI->getArgOperand(2)->getType()->isIntegerTy()) + return nullptr; + Value *V = B.CreateTrunc(CI->getArgOperand(2), B.getInt8Ty(), "char"); + Value *Ptr = CastToCStr(CI->getArgOperand(0), B); + B.CreateStore(V, Ptr); + Ptr = B.CreateGEP(Ptr, B.getInt32(1), "nul"); + B.CreateStore(B.getInt8(0), Ptr); - // The sprintf result is the unincremented number of bytes in the string. - return B.CreateIntCast(Len, CI->getType(), false); - } - return nullptr; + return ConstantInt::get(CI->getType(), 1); } - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // Require two fixed pointer arguments and an integer result. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy() || - !FT->getReturnType()->isIntegerTy()) + if (FormatStr[1] == 's') { + // These optimizations require DataLayout. + if (!DL) return nullptr; - if (Value *V = OptimizeFixedFormatString(Callee, CI, B)) { - return V; - } + // sprintf(dest, "%s", str) -> llvm.memcpy(dest, str, strlen(str)+1, 1) + if (!CI->getArgOperand(2)->getType()->isPointerTy()) + return nullptr; - // sprintf(str, format, ...) -> siprintf(str, format, ...) if no floating - // point arguments. - if (TLI->has(LibFunc::siprintf) && !callHasFloatingPointArgument(CI)) { - Module *M = B.GetInsertBlock()->getParent()->getParent(); - Constant *SIPrintFFn = - M->getOrInsertFunction("siprintf", FT, Callee->getAttributes()); - CallInst *New = cast<CallInst>(CI->clone()); - New->setCalledFunction(SIPrintFFn); - B.Insert(New); - return New; - } - return nullptr; + Value *Len = EmitStrLen(CI->getArgOperand(2), B, DL, TLI); + if (!Len) + return nullptr; + Value *IncLen = + B.CreateAdd(Len, ConstantInt::get(Len->getType(), 1), "leninc"); + B.CreateMemCpy(CI->getArgOperand(0), CI->getArgOperand(2), IncLen, 1); + + // The sprintf result is the unincremented number of bytes in the string. + return B.CreateIntCast(Len, CI->getType(), false); } -}; + return nullptr; +} -struct FPrintFOpt : public LibCallOptimization { - Value *optimizeFixedFormatString(Function *Callee, CallInst *CI, - IRBuilder<> &B) { - ErrorReportingOpt ER(/* StreamArg = */ 0); - (void) ER.callOptimizer(Callee, CI, B); +Value *LibCallSimplifier::optimizeSPrintF(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // Require two fixed pointer arguments and an integer result. + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 2 || !FT->getParamType(0)->isPointerTy() || + !FT->getParamType(1)->isPointerTy() || + !FT->getReturnType()->isIntegerTy()) + return nullptr; - // All the optimizations depend on the format string. - StringRef FormatStr; - if (!getConstantStringInfo(CI->getArgOperand(1), FormatStr)) - return nullptr; + if (Value *V = optimizeSPrintFString(CI, B)) { + return V; + } - // Do not do any of the following transformations if the fprintf return - // value is used, in general the fprintf return value is not compatible - // with fwrite(), fputc() or fputs(). - if (!CI->use_empty()) - return nullptr; + // sprintf(str, format, ...) -> siprintf(str, format, ...) if no floating + // point arguments. + if (TLI->has(LibFunc::siprintf) && !callHasFloatingPointArgument(CI)) { + Module *M = B.GetInsertBlock()->getParent()->getParent(); + Constant *SIPrintFFn = + M->getOrInsertFunction("siprintf", FT, Callee->getAttributes()); + CallInst *New = cast<CallInst>(CI->clone()); + New->setCalledFunction(SIPrintFFn); + B.Insert(New); + return New; + } + return nullptr; +} + +Value *LibCallSimplifier::optimizeFPrintFString(CallInst *CI, IRBuilder<> &B) { + optimizeErrorReporting(CI, B, 0); - // fprintf(F, "foo") --> fwrite("foo", 3, 1, F) - if (CI->getNumArgOperands() == 2) { - for (unsigned i = 0, e = FormatStr.size(); i != e; ++i) - if (FormatStr[i] == '%') // Could handle %% -> % if we cared. - return nullptr; // We found a format specifier. + // All the optimizations depend on the format string. + StringRef FormatStr; + if (!getConstantStringInfo(CI->getArgOperand(1), FormatStr)) + return nullptr; - // These optimizations require DataLayout. - if (!DL) return nullptr; + // Do not do any of the following transformations if the fprintf return + // value is used, in general the fprintf return value is not compatible + // with fwrite(), fputc() or fputs(). + if (!CI->use_empty()) + return nullptr; - return EmitFWrite(CI->getArgOperand(1), - ConstantInt::get(DL->getIntPtrType(*Context), - FormatStr.size()), - CI->getArgOperand(0), B, DL, TLI); - } + // fprintf(F, "foo") --> fwrite("foo", 3, 1, F) + if (CI->getNumArgOperands() == 2) { + for (unsigned i = 0, e = FormatStr.size(); i != e; ++i) + if (FormatStr[i] == '%') // Could handle %% -> % if we cared. + return nullptr; // We found a format specifier. - // The remaining optimizations require the format string to be "%s" or "%c" - // and have an extra operand. - if (FormatStr.size() != 2 || FormatStr[0] != '%' || - CI->getNumArgOperands() < 3) + // These optimizations require DataLayout. + if (!DL) return nullptr; - // Decode the second character of the format string. - if (FormatStr[1] == 'c') { - // fprintf(F, "%c", chr) --> fputc(chr, F) - if (!CI->getArgOperand(2)->getType()->isIntegerTy()) return nullptr; - return EmitFPutC(CI->getArgOperand(2), CI->getArgOperand(0), B, DL, TLI); - } + return EmitFWrite( + CI->getArgOperand(1), + ConstantInt::get(DL->getIntPtrType(CI->getContext()), FormatStr.size()), + CI->getArgOperand(0), B, DL, TLI); + } - if (FormatStr[1] == 's') { - // fprintf(F, "%s", str) --> fputs(str, F) - if (!CI->getArgOperand(2)->getType()->isPointerTy()) - return nullptr; - return EmitFPutS(CI->getArgOperand(2), CI->getArgOperand(0), B, DL, TLI); - } + // The remaining optimizations require the format string to be "%s" or "%c" + // and have an extra operand. + if (FormatStr.size() != 2 || FormatStr[0] != '%' || + CI->getNumArgOperands() < 3) return nullptr; - } - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // Require two fixed paramters as pointers and integer result. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy() || - !FT->getReturnType()->isIntegerTy()) + // Decode the second character of the format string. + if (FormatStr[1] == 'c') { + // fprintf(F, "%c", chr) --> fputc(chr, F) + if (!CI->getArgOperand(2)->getType()->isIntegerTy()) return nullptr; + return EmitFPutC(CI->getArgOperand(2), CI->getArgOperand(0), B, DL, TLI); + } - if (Value *V = optimizeFixedFormatString(Callee, CI, B)) { - return V; - } + if (FormatStr[1] == 's') { + // fprintf(F, "%s", str) --> fputs(str, F) + if (!CI->getArgOperand(2)->getType()->isPointerTy()) + return nullptr; + return EmitFPutS(CI->getArgOperand(2), CI->getArgOperand(0), B, DL, TLI); + } + return nullptr; +} - // fprintf(stream, format, ...) -> fiprintf(stream, format, ...) if no - // floating point arguments. - if (TLI->has(LibFunc::fiprintf) && !callHasFloatingPointArgument(CI)) { - Module *M = B.GetInsertBlock()->getParent()->getParent(); - Constant *FIPrintFFn = - M->getOrInsertFunction("fiprintf", FT, Callee->getAttributes()); - CallInst *New = cast<CallInst>(CI->clone()); - New->setCalledFunction(FIPrintFFn); - B.Insert(New); - return New; - } +Value *LibCallSimplifier::optimizeFPrintF(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // Require two fixed paramters as pointers and integer result. + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 2 || !FT->getParamType(0)->isPointerTy() || + !FT->getParamType(1)->isPointerTy() || + !FT->getReturnType()->isIntegerTy()) return nullptr; + + if (Value *V = optimizeFPrintFString(CI, B)) { + return V; } -}; -struct FWriteOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - ErrorReportingOpt ER(/* StreamArg = */ 3); - (void) ER.callOptimizer(Callee, CI, B); + // fprintf(stream, format, ...) -> fiprintf(stream, format, ...) if no + // floating point arguments. + if (TLI->has(LibFunc::fiprintf) && !callHasFloatingPointArgument(CI)) { + Module *M = B.GetInsertBlock()->getParent()->getParent(); + Constant *FIPrintFFn = + M->getOrInsertFunction("fiprintf", FT, Callee->getAttributes()); + CallInst *New = cast<CallInst>(CI->clone()); + New->setCalledFunction(FIPrintFFn); + B.Insert(New); + return New; + } + return nullptr; +} - // Require a pointer, an integer, an integer, a pointer, returning integer. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 4 || !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isIntegerTy() || - !FT->getParamType(2)->isIntegerTy() || - !FT->getParamType(3)->isPointerTy() || - !FT->getReturnType()->isIntegerTy()) - return nullptr; +Value *LibCallSimplifier::optimizeFWrite(CallInst *CI, IRBuilder<> &B) { + optimizeErrorReporting(CI, B, 3); - // Get the element size and count. - ConstantInt *SizeC = dyn_cast<ConstantInt>(CI->getArgOperand(1)); - ConstantInt *CountC = dyn_cast<ConstantInt>(CI->getArgOperand(2)); - if (!SizeC || !CountC) return nullptr; - uint64_t Bytes = SizeC->getZExtValue()*CountC->getZExtValue(); - - // If this is writing zero records, remove the call (it's a noop). - if (Bytes == 0) - return ConstantInt::get(CI->getType(), 0); - - // If this is writing one byte, turn it into fputc. - // This optimisation is only valid, if the return value is unused. - if (Bytes == 1 && CI->use_empty()) { // fwrite(S,1,1,F) -> fputc(S[0],F) - Value *Char = B.CreateLoad(CastToCStr(CI->getArgOperand(0), B), "char"); - Value *NewCI = EmitFPutC(Char, CI->getArgOperand(3), B, DL, TLI); - return NewCI ? ConstantInt::get(CI->getType(), 1) : nullptr; - } + Function *Callee = CI->getCalledFunction(); + // Require a pointer, an integer, an integer, a pointer, returning integer. + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 4 || !FT->getParamType(0)->isPointerTy() || + !FT->getParamType(1)->isIntegerTy() || + !FT->getParamType(2)->isIntegerTy() || + !FT->getParamType(3)->isPointerTy() || + !FT->getReturnType()->isIntegerTy()) + return nullptr; + // Get the element size and count. + ConstantInt *SizeC = dyn_cast<ConstantInt>(CI->getArgOperand(1)); + ConstantInt *CountC = dyn_cast<ConstantInt>(CI->getArgOperand(2)); + if (!SizeC || !CountC) return nullptr; - } -}; + uint64_t Bytes = SizeC->getZExtValue() * CountC->getZExtValue(); -struct FPutsOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - ErrorReportingOpt ER(/* StreamArg = */ 1); - (void) ER.callOptimizer(Callee, CI, B); + // If this is writing zero records, remove the call (it's a noop). + if (Bytes == 0) + return ConstantInt::get(CI->getType(), 0); - // These optimizations require DataLayout. - if (!DL) return nullptr; + // If this is writing one byte, turn it into fputc. + // This optimisation is only valid, if the return value is unused. + if (Bytes == 1 && CI->use_empty()) { // fwrite(S,1,1,F) -> fputc(S[0],F) + Value *Char = B.CreateLoad(CastToCStr(CI->getArgOperand(0), B), "char"); + Value *NewCI = EmitFPutC(Char, CI->getArgOperand(3), B, DL, TLI); + return NewCI ? ConstantInt::get(CI->getType(), 1) : nullptr; + } - // Require two pointers. Also, we can't optimize if return value is used. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy() || - !CI->use_empty()) - return nullptr; + return nullptr; +} - // fputs(s,F) --> fwrite(s,1,strlen(s),F) - uint64_t Len = GetStringLength(CI->getArgOperand(0)); - if (!Len) return nullptr; - // Known to have no uses (see above). - return EmitFWrite(CI->getArgOperand(0), - ConstantInt::get(DL->getIntPtrType(*Context), Len-1), - CI->getArgOperand(1), B, DL, TLI); - } -}; - -struct PutsOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // Require one fixed pointer argument and an integer/void result. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() < 1 || !FT->getParamType(0)->isPointerTy() || - !(FT->getReturnType()->isIntegerTy() || - FT->getReturnType()->isVoidTy())) - return nullptr; +Value *LibCallSimplifier::optimizeFPuts(CallInst *CI, IRBuilder<> &B) { + optimizeErrorReporting(CI, B, 1); - // Check for a constant string. - StringRef Str; - if (!getConstantStringInfo(CI->getArgOperand(0), Str)) - return nullptr; + Function *Callee = CI->getCalledFunction(); - if (Str.empty() && CI->use_empty()) { - // puts("") -> putchar('\n') - Value *Res = EmitPutChar(B.getInt32('\n'), B, DL, TLI); - if (CI->use_empty() || !Res) return Res; - return B.CreateIntCast(Res, CI->getType(), true); - } + // These optimizations require DataLayout. + if (!DL) + return nullptr; + // Require two pointers. Also, we can't optimize if return value is used. + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 2 || !FT->getParamType(0)->isPointerTy() || + !FT->getParamType(1)->isPointerTy() || !CI->use_empty()) + return nullptr; + + // fputs(s,F) --> fwrite(s,1,strlen(s),F) + uint64_t Len = GetStringLength(CI->getArgOperand(0)); + if (!Len) return nullptr; - } -}; -} // End anonymous namespace. + // Known to have no uses (see above). + return EmitFWrite( + CI->getArgOperand(0), + ConstantInt::get(DL->getIntPtrType(CI->getContext()), Len - 1), + CI->getArgOperand(1), B, DL, TLI); +} -namespace llvm { +Value *LibCallSimplifier::optimizePuts(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // Require one fixed pointer argument and an integer/void result. + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() < 1 || !FT->getParamType(0)->isPointerTy() || + !(FT->getReturnType()->isIntegerTy() || FT->getReturnType()->isVoidTy())) + return nullptr; -class LibCallSimplifierImpl { - const DataLayout *DL; - const TargetLibraryInfo *TLI; - const LibCallSimplifier *LCS; - bool UnsafeFPShrink; + // Check for a constant string. + StringRef Str; + if (!getConstantStringInfo(CI->getArgOperand(0), Str)) + return nullptr; - // Math library call optimizations. - CosOpt Cos; - PowOpt Pow; - Exp2Opt Exp2; -public: - LibCallSimplifierImpl(const DataLayout *DL, const TargetLibraryInfo *TLI, - const LibCallSimplifier *LCS, - bool UnsafeFPShrink = false) - : Cos(UnsafeFPShrink), Pow(UnsafeFPShrink), Exp2(UnsafeFPShrink) { - this->DL = DL; - this->TLI = TLI; - this->LCS = LCS; - this->UnsafeFPShrink = UnsafeFPShrink; + if (Str.empty() && CI->use_empty()) { + // puts("") -> putchar('\n') + Value *Res = EmitPutChar(B.getInt32('\n'), B, DL, TLI); + if (CI->use_empty() || !Res) + return Res; + return B.CreateIntCast(Res, CI->getType(), true); } - Value *optimizeCall(CallInst *CI); - LibCallOptimization *lookupOptimization(CallInst *CI); - bool hasFloatVersion(StringRef FuncName); -}; + return nullptr; +} -bool LibCallSimplifierImpl::hasFloatVersion(StringRef FuncName) { +bool LibCallSimplifier::hasFloatVersion(StringRef FuncName) { LibFunc::Func Func; SmallString<20> FloatFuncName = FuncName; FloatFuncName += 'f'; @@ -2048,263 +1857,239 @@ bool LibCallSimplifierImpl::hasFloatVersion(StringRef FuncName) { return false; } -// Fortified library call optimizations. -static MemCpyChkOpt MemCpyChk; -static MemMoveChkOpt MemMoveChk; -static MemSetChkOpt MemSetChk; -static StrCpyChkOpt StrCpyChk; -static StpCpyChkOpt StpCpyChk; -static StrNCpyChkOpt StrNCpyChk; - -// String library call optimizations. -static StrCatOpt StrCat; -static StrNCatOpt StrNCat; -static StrChrOpt StrChr; -static StrRChrOpt StrRChr; -static StrCmpOpt StrCmp; -static StrNCmpOpt StrNCmp; -static StrCpyOpt StrCpy; -static StpCpyOpt StpCpy; -static StrNCpyOpt StrNCpy; -static StrLenOpt StrLen; -static StrPBrkOpt StrPBrk; -static StrToOpt StrTo; -static StrSpnOpt StrSpn; -static StrCSpnOpt StrCSpn; -static StrStrOpt StrStr; - -// Memory library call optimizations. -static MemCmpOpt MemCmp; -static MemCpyOpt MemCpy; -static MemMoveOpt MemMove; -static MemSetOpt MemSet; - -// Math library call optimizations. -static UnaryDoubleFPOpt UnaryDoubleFP(false); -static BinaryDoubleFPOpt BinaryDoubleFP(false); -static UnaryDoubleFPOpt UnsafeUnaryDoubleFP(true); -static SinCosPiOpt SinCosPi; - - // Integer library call optimizations. -static FFSOpt FFS; -static AbsOpt Abs; -static IsDigitOpt IsDigit; -static IsAsciiOpt IsAscii; -static ToAsciiOpt ToAscii; - -// Formatting and IO library call optimizations. -static ErrorReportingOpt ErrorReporting; -static ErrorReportingOpt ErrorReporting0(0); -static ErrorReportingOpt ErrorReporting1(1); -static PrintFOpt PrintF; -static SPrintFOpt SPrintF; -static FPrintFOpt FPrintF; -static FWriteOpt FWrite; -static FPutsOpt FPuts; -static PutsOpt Puts; - -LibCallOptimization *LibCallSimplifierImpl::lookupOptimization(CallInst *CI) { +Value *LibCallSimplifier::optimizeStringMemoryLibCall(CallInst *CI, + IRBuilder<> &Builder) { + LibFunc::Func Func; + Function *Callee = CI->getCalledFunction(); + StringRef FuncName = Callee->getName(); + + // Check for string/memory library functions. + if (TLI->getLibFunc(FuncName, Func) && TLI->has(Func)) { + // Make sure we never change the calling convention. + assert((ignoreCallingConv(Func) || + CI->getCallingConv() == llvm::CallingConv::C) && + "Optimizing string/memory libcall would change the calling convention"); + switch (Func) { + case LibFunc::strcat: + return optimizeStrCat(CI, Builder); + case LibFunc::strncat: + return optimizeStrNCat(CI, Builder); + case LibFunc::strchr: + return optimizeStrChr(CI, Builder); + case LibFunc::strrchr: + return optimizeStrRChr(CI, Builder); + case LibFunc::strcmp: + return optimizeStrCmp(CI, Builder); + case LibFunc::strncmp: + return optimizeStrNCmp(CI, Builder); + case LibFunc::strcpy: + return optimizeStrCpy(CI, Builder); + case LibFunc::stpcpy: + return optimizeStpCpy(CI, Builder); + case LibFunc::strncpy: + return optimizeStrNCpy(CI, Builder); + case LibFunc::strlen: + return optimizeStrLen(CI, Builder); + case LibFunc::strpbrk: + return optimizeStrPBrk(CI, Builder); + case LibFunc::strtol: + case LibFunc::strtod: + case LibFunc::strtof: + case LibFunc::strtoul: + case LibFunc::strtoll: + case LibFunc::strtold: + case LibFunc::strtoull: + return optimizeStrTo(CI, Builder); + case LibFunc::strspn: + return optimizeStrSpn(CI, Builder); + case LibFunc::strcspn: + return optimizeStrCSpn(CI, Builder); + case LibFunc::strstr: + return optimizeStrStr(CI, Builder); + case LibFunc::memcmp: + return optimizeMemCmp(CI, Builder); + case LibFunc::memcpy: + return optimizeMemCpy(CI, Builder); + case LibFunc::memmove: + return optimizeMemMove(CI, Builder); + case LibFunc::memset: + return optimizeMemSet(CI, Builder); + default: + break; + } + } + return nullptr; +} + +Value *LibCallSimplifier::optimizeCall(CallInst *CI) { + if (CI->isNoBuiltin()) + return nullptr; + LibFunc::Func Func; Function *Callee = CI->getCalledFunction(); StringRef FuncName = Callee->getName(); + IRBuilder<> Builder(CI); + bool isCallingConvC = CI->getCallingConv() == llvm::CallingConv::C; + + // Command-line parameter overrides function attribute. + if (EnableUnsafeFPShrink.getNumOccurrences() > 0) + UnsafeFPShrink = EnableUnsafeFPShrink; + else if (Callee->hasFnAttribute("unsafe-fp-math")) { + // FIXME: This is the same problem as described in optimizeSqrt(). + // If calls gain access to IR-level FMF, then use that instead of a + // function attribute. + + // Check for unsafe-fp-math = true. + Attribute Attr = Callee->getFnAttribute("unsafe-fp-math"); + if (Attr.getValueAsString() == "true") + UnsafeFPShrink = true; + } - // Next check for intrinsics. + // First, check for intrinsics. if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) { + if (!isCallingConvC) + return nullptr; switch (II->getIntrinsicID()) { case Intrinsic::pow: - return &Pow; + return optimizePow(CI, Builder); case Intrinsic::exp2: - return &Exp2; + return optimizeExp2(CI, Builder); + case Intrinsic::fabs: + return optimizeFabs(CI, Builder); + case Intrinsic::sqrt: + return optimizeSqrt(CI, Builder); default: - return nullptr; + return nullptr; } } + // Also try to simplify calls to fortified library functions. + if (Value *SimplifiedFortifiedCI = FortifiedSimplifier.optimizeCall(CI)) { + // Try to further simplify the result. + CallInst *SimplifiedCI = dyn_cast<CallInst>(SimplifiedFortifiedCI); + if (SimplifiedCI && SimplifiedCI->getCalledFunction()) + if (Value *V = optimizeStringMemoryLibCall(SimplifiedCI, Builder)) + return V; + return SimplifiedFortifiedCI; + } + // Then check for known library functions. if (TLI->getLibFunc(FuncName, Func) && TLI->has(Func)) { + // We never change the calling convention. + if (!ignoreCallingConv(Func) && !isCallingConvC) + return nullptr; + if (Value *V = optimizeStringMemoryLibCall(CI, Builder)) + return V; switch (Func) { - case LibFunc::strcat: - return &StrCat; - case LibFunc::strncat: - return &StrNCat; - case LibFunc::strchr: - return &StrChr; - case LibFunc::strrchr: - return &StrRChr; - case LibFunc::strcmp: - return &StrCmp; - case LibFunc::strncmp: - return &StrNCmp; - case LibFunc::strcpy: - return &StrCpy; - case LibFunc::stpcpy: - return &StpCpy; - case LibFunc::strncpy: - return &StrNCpy; - case LibFunc::strlen: - return &StrLen; - case LibFunc::strpbrk: - return &StrPBrk; - case LibFunc::strtol: - case LibFunc::strtod: - case LibFunc::strtof: - case LibFunc::strtoul: - case LibFunc::strtoll: - case LibFunc::strtold: - case LibFunc::strtoull: - return &StrTo; - case LibFunc::strspn: - return &StrSpn; - case LibFunc::strcspn: - return &StrCSpn; - case LibFunc::strstr: - return &StrStr; - case LibFunc::memcmp: - return &MemCmp; - case LibFunc::memcpy: - return &MemCpy; - case LibFunc::memmove: - return &MemMove; - case LibFunc::memset: - return &MemSet; - case LibFunc::cosf: - case LibFunc::cos: - case LibFunc::cosl: - return &Cos; - case LibFunc::sinpif: - case LibFunc::sinpi: - case LibFunc::cospif: - case LibFunc::cospi: - return &SinCosPi; - case LibFunc::powf: - case LibFunc::pow: - case LibFunc::powl: - return &Pow; - case LibFunc::exp2l: - case LibFunc::exp2: - case LibFunc::exp2f: - return &Exp2; - case LibFunc::ffs: - case LibFunc::ffsl: - case LibFunc::ffsll: - return &FFS; - case LibFunc::abs: - case LibFunc::labs: - case LibFunc::llabs: - return &Abs; - case LibFunc::isdigit: - return &IsDigit; - case LibFunc::isascii: - return &IsAscii; - case LibFunc::toascii: - return &ToAscii; - case LibFunc::printf: - return &PrintF; - case LibFunc::sprintf: - return &SPrintF; - case LibFunc::fprintf: - return &FPrintF; - case LibFunc::fwrite: - return &FWrite; - case LibFunc::fputs: - return &FPuts; - case LibFunc::puts: - return &Puts; - case LibFunc::perror: - return &ErrorReporting; - case LibFunc::vfprintf: - case LibFunc::fiprintf: - return &ErrorReporting0; - case LibFunc::fputc: - return &ErrorReporting1; - case LibFunc::ceil: - case LibFunc::fabs: - case LibFunc::floor: - case LibFunc::rint: - case LibFunc::round: - case LibFunc::nearbyint: - case LibFunc::trunc: - if (hasFloatVersion(FuncName)) - return &UnaryDoubleFP; - return nullptr; - case LibFunc::acos: - case LibFunc::acosh: - case LibFunc::asin: - case LibFunc::asinh: - case LibFunc::atan: - case LibFunc::atanh: - case LibFunc::cbrt: - case LibFunc::cosh: - case LibFunc::exp: - case LibFunc::exp10: - case LibFunc::expm1: - case LibFunc::log: - case LibFunc::log10: - case LibFunc::log1p: - case LibFunc::log2: - case LibFunc::logb: - case LibFunc::sin: - case LibFunc::sinh: - case LibFunc::sqrt: - case LibFunc::tan: - case LibFunc::tanh: - if (UnsafeFPShrink && hasFloatVersion(FuncName)) - return &UnsafeUnaryDoubleFP; - return nullptr; - case LibFunc::fmin: - case LibFunc::fmax: - if (hasFloatVersion(FuncName)) - return &BinaryDoubleFP; - return nullptr; - case LibFunc::memcpy_chk: - return &MemCpyChk; - default: - return nullptr; - } - } - - // Finally check for fortified library calls. - if (FuncName.endswith("_chk")) { - if (FuncName == "__memmove_chk") - return &MemMoveChk; - else if (FuncName == "__memset_chk") - return &MemSetChk; - else if (FuncName == "__strcpy_chk") - return &StrCpyChk; - else if (FuncName == "__stpcpy_chk") - return &StpCpyChk; - else if (FuncName == "__strncpy_chk") - return &StrNCpyChk; - else if (FuncName == "__stpncpy_chk") - return &StrNCpyChk; - } - - return nullptr; - -} - -Value *LibCallSimplifierImpl::optimizeCall(CallInst *CI) { - LibCallOptimization *LCO = lookupOptimization(CI); - if (LCO) { - IRBuilder<> Builder(CI); - return LCO->optimizeCall(CI, DL, TLI, LCS, Builder); + case LibFunc::cosf: + case LibFunc::cos: + case LibFunc::cosl: + return optimizeCos(CI, Builder); + case LibFunc::sinpif: + case LibFunc::sinpi: + case LibFunc::cospif: + case LibFunc::cospi: + return optimizeSinCosPi(CI, Builder); + case LibFunc::powf: + case LibFunc::pow: + case LibFunc::powl: + return optimizePow(CI, Builder); + case LibFunc::exp2l: + case LibFunc::exp2: + case LibFunc::exp2f: + return optimizeExp2(CI, Builder); + case LibFunc::fabsf: + case LibFunc::fabs: + case LibFunc::fabsl: + return optimizeFabs(CI, Builder); + case LibFunc::sqrtf: + case LibFunc::sqrt: + case LibFunc::sqrtl: + return optimizeSqrt(CI, Builder); + case LibFunc::ffs: + case LibFunc::ffsl: + case LibFunc::ffsll: + return optimizeFFS(CI, Builder); + case LibFunc::abs: + case LibFunc::labs: + case LibFunc::llabs: + return optimizeAbs(CI, Builder); + case LibFunc::isdigit: + return optimizeIsDigit(CI, Builder); + case LibFunc::isascii: + return optimizeIsAscii(CI, Builder); + case LibFunc::toascii: + return optimizeToAscii(CI, Builder); + case LibFunc::printf: + return optimizePrintF(CI, Builder); + case LibFunc::sprintf: + return optimizeSPrintF(CI, Builder); + case LibFunc::fprintf: + return optimizeFPrintF(CI, Builder); + case LibFunc::fwrite: + return optimizeFWrite(CI, Builder); + case LibFunc::fputs: + return optimizeFPuts(CI, Builder); + case LibFunc::puts: + return optimizePuts(CI, Builder); + case LibFunc::perror: + return optimizeErrorReporting(CI, Builder); + case LibFunc::vfprintf: + case LibFunc::fiprintf: + return optimizeErrorReporting(CI, Builder, 0); + case LibFunc::fputc: + return optimizeErrorReporting(CI, Builder, 1); + case LibFunc::ceil: + case LibFunc::floor: + case LibFunc::rint: + case LibFunc::round: + case LibFunc::nearbyint: + case LibFunc::trunc: + if (hasFloatVersion(FuncName)) + return optimizeUnaryDoubleFP(CI, Builder, false); + return nullptr; + case LibFunc::acos: + case LibFunc::acosh: + case LibFunc::asin: + case LibFunc::asinh: + case LibFunc::atan: + case LibFunc::atanh: + case LibFunc::cbrt: + case LibFunc::cosh: + case LibFunc::exp: + case LibFunc::exp10: + case LibFunc::expm1: + case LibFunc::log: + case LibFunc::log10: + case LibFunc::log1p: + case LibFunc::log2: + case LibFunc::logb: + case LibFunc::sin: + case LibFunc::sinh: + case LibFunc::tan: + case LibFunc::tanh: + if (UnsafeFPShrink && hasFloatVersion(FuncName)) + return optimizeUnaryDoubleFP(CI, Builder, true); + return nullptr; + case LibFunc::copysign: + case LibFunc::fmin: + case LibFunc::fmax: + if (hasFloatVersion(FuncName)) + return optimizeBinaryDoubleFP(CI, Builder); + return nullptr; + default: + return nullptr; + } } return nullptr; } LibCallSimplifier::LibCallSimplifier(const DataLayout *DL, - const TargetLibraryInfo *TLI, - bool UnsafeFPShrink) { - Impl = new LibCallSimplifierImpl(DL, TLI, this, UnsafeFPShrink); -} - -LibCallSimplifier::~LibCallSimplifier() { - delete Impl; -} - -Value *LibCallSimplifier::optimizeCall(CallInst *CI) { - if (CI->isNoBuiltin()) return nullptr; - return Impl->optimizeCall(CI); + const TargetLibraryInfo *TLI) : + FortifiedSimplifier(DL, TLI), + DL(DL), + TLI(TLI), + UnsafeFPShrink(false) { } void LibCallSimplifier::replaceAllUsesWith(Instruction *I, Value *With) const { @@ -2312,8 +2097,6 @@ void LibCallSimplifier::replaceAllUsesWith(Instruction *I, Value *With) const { I->eraseFromParent(); } -} - // TODO: // Additional cases that we need to add to this file: // @@ -2361,3 +2144,184 @@ void LibCallSimplifier::replaceAllUsesWith(Instruction *I, Value *With) const { // * trunc(cnst) -> cnst' // // + +//===----------------------------------------------------------------------===// +// Fortified Library Call Optimizations +//===----------------------------------------------------------------------===// + +bool FortifiedLibCallSimplifier::isFortifiedCallFoldable(CallInst *CI, + unsigned ObjSizeOp, + unsigned SizeOp, + bool isString) { + if (CI->getArgOperand(ObjSizeOp) == CI->getArgOperand(SizeOp)) + return true; + if (ConstantInt *ObjSizeCI = + dyn_cast<ConstantInt>(CI->getArgOperand(ObjSizeOp))) { + if (ObjSizeCI->isAllOnesValue()) + return true; + // If the object size wasn't -1 (unknown), bail out if we were asked to. + if (OnlyLowerUnknownSize) + return false; + if (isString) { + uint64_t Len = GetStringLength(CI->getArgOperand(SizeOp)); + // If the length is 0 we don't know how long it is and so we can't + // remove the check. + if (Len == 0) + return false; + return ObjSizeCI->getZExtValue() >= Len; + } + if (ConstantInt *SizeCI = dyn_cast<ConstantInt>(CI->getArgOperand(SizeOp))) + return ObjSizeCI->getZExtValue() >= SizeCI->getZExtValue(); + } + return false; +} + +Value *FortifiedLibCallSimplifier::optimizeMemCpyChk(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + + if (!checkStringCopyLibFuncSignature(Callee, LibFunc::memcpy_chk, DL)) + return nullptr; + + if (isFortifiedCallFoldable(CI, 3, 2, false)) { + B.CreateMemCpy(CI->getArgOperand(0), CI->getArgOperand(1), + CI->getArgOperand(2), 1); + return CI->getArgOperand(0); + } + return nullptr; +} + +Value *FortifiedLibCallSimplifier::optimizeMemMoveChk(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + + if (!checkStringCopyLibFuncSignature(Callee, LibFunc::memmove_chk, DL)) + return nullptr; + + if (isFortifiedCallFoldable(CI, 3, 2, false)) { + B.CreateMemMove(CI->getArgOperand(0), CI->getArgOperand(1), + CI->getArgOperand(2), 1); + return CI->getArgOperand(0); + } + return nullptr; +} + +Value *FortifiedLibCallSimplifier::optimizeMemSetChk(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + + if (!checkStringCopyLibFuncSignature(Callee, LibFunc::memset_chk, DL)) + return nullptr; + + if (isFortifiedCallFoldable(CI, 3, 2, false)) { + Value *Val = B.CreateIntCast(CI->getArgOperand(1), B.getInt8Ty(), false); + B.CreateMemSet(CI->getArgOperand(0), Val, CI->getArgOperand(2), 1); + return CI->getArgOperand(0); + } + return nullptr; +} + +Value *FortifiedLibCallSimplifier::optimizeStrCpyChk(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + StringRef Name = Callee->getName(); + LibFunc::Func Func = + Name.startswith("str") ? LibFunc::strcpy_chk : LibFunc::stpcpy_chk; + + if (!checkStringCopyLibFuncSignature(Callee, Func, DL)) + return nullptr; + + Value *Dst = CI->getArgOperand(0), *Src = CI->getArgOperand(1), + *ObjSize = CI->getArgOperand(2); + + // __stpcpy_chk(x,x,...) -> x+strlen(x) + if (!OnlyLowerUnknownSize && Dst == Src) { + Value *StrLen = EmitStrLen(Src, B, DL, TLI); + return StrLen ? B.CreateInBoundsGEP(Dst, StrLen) : nullptr; + } + + // If a) we don't have any length information, or b) we know this will + // fit then just lower to a plain st[rp]cpy. Otherwise we'll keep our + // st[rp]cpy_chk call which may fail at runtime if the size is too long. + // TODO: It might be nice to get a maximum length out of the possible + // string lengths for varying. + if (isFortifiedCallFoldable(CI, 2, 1, true)) { + Value *Ret = EmitStrCpy(Dst, Src, B, DL, TLI, Name.substr(2, 6)); + return Ret; + } else if (!OnlyLowerUnknownSize) { + // Maybe we can stil fold __st[rp]cpy_chk to __memcpy_chk. + uint64_t Len = GetStringLength(Src); + if (Len == 0) + return nullptr; + + // This optimization requires DataLayout. + if (!DL) + return nullptr; + + Type *SizeTTy = DL->getIntPtrType(CI->getContext()); + Value *LenV = ConstantInt::get(SizeTTy, Len); + Value *Ret = EmitMemCpyChk(Dst, Src, LenV, ObjSize, B, DL, TLI); + // If the function was an __stpcpy_chk, and we were able to fold it into + // a __memcpy_chk, we still need to return the correct end pointer. + if (Ret && Func == LibFunc::stpcpy_chk) + return B.CreateGEP(Dst, ConstantInt::get(SizeTTy, Len - 1)); + return Ret; + } + return nullptr; +} + +Value *FortifiedLibCallSimplifier::optimizeStrNCpyChk(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + StringRef Name = Callee->getName(); + LibFunc::Func Func = + Name.startswith("str") ? LibFunc::strncpy_chk : LibFunc::stpncpy_chk; + + if (!checkStringCopyLibFuncSignature(Callee, Func, DL)) + return nullptr; + if (isFortifiedCallFoldable(CI, 3, 2, false)) { + Value *Ret = + EmitStrNCpy(CI->getArgOperand(0), CI->getArgOperand(1), + CI->getArgOperand(2), B, DL, TLI, Name.substr(2, 7)); + return Ret; + } + return nullptr; +} + +Value *FortifiedLibCallSimplifier::optimizeCall(CallInst *CI) { + if (CI->isNoBuiltin()) + return nullptr; + + LibFunc::Func Func; + Function *Callee = CI->getCalledFunction(); + StringRef FuncName = Callee->getName(); + IRBuilder<> Builder(CI); + bool isCallingConvC = CI->getCallingConv() == llvm::CallingConv::C; + + // First, check that this is a known library functions. + if (!TLI->getLibFunc(FuncName, Func) || !TLI->has(Func)) + return nullptr; + + // We never change the calling convention. + if (!ignoreCallingConv(Func) && !isCallingConvC) + return nullptr; + + switch (Func) { + case LibFunc::memcpy_chk: + return optimizeMemCpyChk(CI, Builder); + case LibFunc::memmove_chk: + return optimizeMemMoveChk(CI, Builder); + case LibFunc::memset_chk: + return optimizeMemSetChk(CI, Builder); + case LibFunc::stpcpy_chk: + case LibFunc::strcpy_chk: + return optimizeStrCpyChk(CI, Builder); + case LibFunc::stpncpy_chk: + case LibFunc::strncpy_chk: + return optimizeStrNCpyChk(CI, Builder); + default: + break; + } + return nullptr; +} + +FortifiedLibCallSimplifier:: +FortifiedLibCallSimplifier(const DataLayout *DL, const TargetLibraryInfo *TLI, + bool OnlyLowerUnknownSize) + : DL(DL), TLI(TLI), OnlyLowerUnknownSize(OnlyLowerUnknownSize) { +} diff --git a/lib/Transforms/Utils/SymbolRewriter.cpp b/lib/Transforms/Utils/SymbolRewriter.cpp new file mode 100644 index 000000000000..b35a662f17b5 --- /dev/null +++ b/lib/Transforms/Utils/SymbolRewriter.cpp @@ -0,0 +1,527 @@ +//===- SymbolRewriter.cpp - Symbol Rewriter ---------------------*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// SymbolRewriter is a LLVM pass which can rewrite symbols transparently within +// existing code. It is implemented as a compiler pass and is configured via a +// YAML configuration file. +// +// The YAML configuration file format is as follows: +// +// RewriteMapFile := RewriteDescriptors +// RewriteDescriptors := RewriteDescriptor | RewriteDescriptors +// RewriteDescriptor := RewriteDescriptorType ':' '{' RewriteDescriptorFields '}' +// RewriteDescriptorFields := RewriteDescriptorField | RewriteDescriptorFields +// RewriteDescriptorField := FieldIdentifier ':' FieldValue ',' +// RewriteDescriptorType := Identifier +// FieldIdentifier := Identifier +// FieldValue := Identifier +// Identifier := [0-9a-zA-Z]+ +// +// Currently, the following descriptor types are supported: +// +// - function: (function rewriting) +// + Source (original name of the function) +// + Target (explicit transformation) +// + Transform (pattern transformation) +// + Naked (boolean, whether the function is undecorated) +// - global variable: (external linkage global variable rewriting) +// + Source (original name of externally visible variable) +// + Target (explicit transformation) +// + Transform (pattern transformation) +// - global alias: (global alias rewriting) +// + Source (original name of the aliased name) +// + Target (explicit transformation) +// + Transform (pattern transformation) +// +// Note that source and exactly one of [Target, Transform] must be provided +// +// New rewrite descriptors can be created. Addding a new rewrite descriptor +// involves: +// +// a) extended the rewrite descriptor kind enumeration +// (<anonymous>::RewriteDescriptor::RewriteDescriptorType) +// b) implementing the new descriptor +// (c.f. <anonymous>::ExplicitRewriteFunctionDescriptor) +// c) extending the rewrite map parser +// (<anonymous>::RewriteMapParser::parseEntry) +// +// Specify to rewrite the symbols using the `-rewrite-symbols` option, and +// specify the map file to use for the rewriting via the `-rewrite-map-file` +// option. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "symbol-rewriter" +#include "llvm/CodeGen/Passes.h" +#include "llvm/Pass.h" +#include "llvm/PassManager.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/Regex.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/YAMLParser.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/IPO/PassManagerBuilder.h" +#include "llvm/Transforms/Utils/SymbolRewriter.h" + +using namespace llvm; + +static cl::list<std::string> RewriteMapFiles("rewrite-map-file", + cl::desc("Symbol Rewrite Map"), + cl::value_desc("filename")); + +namespace llvm { +namespace SymbolRewriter { +template <RewriteDescriptor::Type DT, typename ValueType, + ValueType *(llvm::Module::*Get)(StringRef) const> +class ExplicitRewriteDescriptor : public RewriteDescriptor { +public: + const std::string Source; + const std::string Target; + + ExplicitRewriteDescriptor(StringRef S, StringRef T, const bool Naked) + : RewriteDescriptor(DT), Source(Naked ? StringRef("\01" + S.str()) : S), + Target(T) {} + + bool performOnModule(Module &M) override; + + static bool classof(const RewriteDescriptor *RD) { + return RD->getType() == DT; + } +}; + +template <RewriteDescriptor::Type DT, typename ValueType, + ValueType *(llvm::Module::*Get)(StringRef) const> +bool ExplicitRewriteDescriptor<DT, ValueType, Get>::performOnModule(Module &M) { + bool Changed = false; + if (ValueType *S = (M.*Get)(Source)) { + if (Value *T = (M.*Get)(Target)) + S->setValueName(T->getValueName()); + else + S->setName(Target); + Changed = true; + } + return Changed; +} + +template <RewriteDescriptor::Type DT, typename ValueType, + ValueType *(llvm::Module::*Get)(StringRef) const, + iterator_range<typename iplist<ValueType>::iterator> + (llvm::Module::*Iterator)()> +class PatternRewriteDescriptor : public RewriteDescriptor { +public: + const std::string Pattern; + const std::string Transform; + + PatternRewriteDescriptor(StringRef P, StringRef T) + : RewriteDescriptor(DT), Pattern(P), Transform(T) { } + + bool performOnModule(Module &M) override; + + static bool classof(const RewriteDescriptor *RD) { + return RD->getType() == DT; + } +}; + +template <RewriteDescriptor::Type DT, typename ValueType, + ValueType *(llvm::Module::*Get)(StringRef) const, + iterator_range<typename iplist<ValueType>::iterator> + (llvm::Module::*Iterator)()> +bool PatternRewriteDescriptor<DT, ValueType, Get, Iterator>:: +performOnModule(Module &M) { + bool Changed = false; + for (auto &C : (M.*Iterator)()) { + std::string Error; + + std::string Name = Regex(Pattern).sub(Transform, C.getName(), &Error); + if (!Error.empty()) + report_fatal_error("unable to transforn " + C.getName() + " in " + + M.getModuleIdentifier() + ": " + Error); + + if (Value *V = (M.*Get)(Name)) + C.setValueName(V->getValueName()); + else + C.setName(Name); + + Changed = true; + } + return Changed; +} + +/// Represents a rewrite for an explicitly named (function) symbol. Both the +/// source function name and target function name of the transformation are +/// explicitly spelt out. +typedef ExplicitRewriteDescriptor<RewriteDescriptor::Type::Function, + llvm::Function, &llvm::Module::getFunction> + ExplicitRewriteFunctionDescriptor; + +/// Represents a rewrite for an explicitly named (global variable) symbol. Both +/// the source variable name and target variable name are spelt out. This +/// applies only to module level variables. +typedef ExplicitRewriteDescriptor<RewriteDescriptor::Type::GlobalVariable, + llvm::GlobalVariable, + &llvm::Module::getGlobalVariable> + ExplicitRewriteGlobalVariableDescriptor; + +/// Represents a rewrite for an explicitly named global alias. Both the source +/// and target name are explicitly spelt out. +typedef ExplicitRewriteDescriptor<RewriteDescriptor::Type::NamedAlias, + llvm::GlobalAlias, + &llvm::Module::getNamedAlias> + ExplicitRewriteNamedAliasDescriptor; + +/// Represents a rewrite for a regular expression based pattern for functions. +/// A pattern for the function name is provided and a transformation for that +/// pattern to determine the target function name create the rewrite rule. +typedef PatternRewriteDescriptor<RewriteDescriptor::Type::Function, + llvm::Function, &llvm::Module::getFunction, + &llvm::Module::functions> + PatternRewriteFunctionDescriptor; + +/// Represents a rewrite for a global variable based upon a matching pattern. +/// Each global variable matching the provided pattern will be transformed as +/// described in the transformation pattern for the target. Applies only to +/// module level variables. +typedef PatternRewriteDescriptor<RewriteDescriptor::Type::GlobalVariable, + llvm::GlobalVariable, + &llvm::Module::getGlobalVariable, + &llvm::Module::globals> + PatternRewriteGlobalVariableDescriptor; + +/// PatternRewriteNamedAliasDescriptor - represents a rewrite for global +/// aliases which match a given pattern. The provided transformation will be +/// applied to each of the matching names. +typedef PatternRewriteDescriptor<RewriteDescriptor::Type::NamedAlias, + llvm::GlobalAlias, + &llvm::Module::getNamedAlias, + &llvm::Module::aliases> + PatternRewriteNamedAliasDescriptor; + +bool RewriteMapParser::parse(const std::string &MapFile, + RewriteDescriptorList *DL) { + ErrorOr<std::unique_ptr<MemoryBuffer>> Mapping = + MemoryBuffer::getFile(MapFile); + + if (!Mapping) + report_fatal_error("unable to read rewrite map '" + MapFile + "': " + + Mapping.getError().message()); + + if (!parse(*Mapping, DL)) + report_fatal_error("unable to parse rewrite map '" + MapFile + "'"); + + return true; +} + +bool RewriteMapParser::parse(std::unique_ptr<MemoryBuffer> &MapFile, + RewriteDescriptorList *DL) { + SourceMgr SM; + yaml::Stream YS(MapFile->getBuffer(), SM); + + for (auto &Document : YS) { + yaml::MappingNode *DescriptorList; + + // ignore empty documents + if (isa<yaml::NullNode>(Document.getRoot())) + continue; + + DescriptorList = dyn_cast<yaml::MappingNode>(Document.getRoot()); + if (!DescriptorList) { + YS.printError(Document.getRoot(), "DescriptorList node must be a map"); + return false; + } + + for (auto &Descriptor : *DescriptorList) + if (!parseEntry(YS, Descriptor, DL)) + return false; + } + + return true; +} + +bool RewriteMapParser::parseEntry(yaml::Stream &YS, yaml::KeyValueNode &Entry, + RewriteDescriptorList *DL) { + yaml::ScalarNode *Key; + yaml::MappingNode *Value; + SmallString<32> KeyStorage; + StringRef RewriteType; + + Key = dyn_cast<yaml::ScalarNode>(Entry.getKey()); + if (!Key) { + YS.printError(Entry.getKey(), "rewrite type must be a scalar"); + return false; + } + + Value = dyn_cast<yaml::MappingNode>(Entry.getValue()); + if (!Value) { + YS.printError(Entry.getValue(), "rewrite descriptor must be a map"); + return false; + } + + RewriteType = Key->getValue(KeyStorage); + if (RewriteType.equals("function")) + return parseRewriteFunctionDescriptor(YS, Key, Value, DL); + else if (RewriteType.equals("global variable")) + return parseRewriteGlobalVariableDescriptor(YS, Key, Value, DL); + else if (RewriteType.equals("global alias")) + return parseRewriteGlobalAliasDescriptor(YS, Key, Value, DL); + + YS.printError(Entry.getKey(), "unknown rewrite type"); + return false; +} + +bool RewriteMapParser:: +parseRewriteFunctionDescriptor(yaml::Stream &YS, yaml::ScalarNode *K, + yaml::MappingNode *Descriptor, + RewriteDescriptorList *DL) { + bool Naked = false; + std::string Source; + std::string Target; + std::string Transform; + + for (auto &Field : *Descriptor) { + yaml::ScalarNode *Key; + yaml::ScalarNode *Value; + SmallString<32> KeyStorage; + SmallString<32> ValueStorage; + StringRef KeyValue; + + Key = dyn_cast<yaml::ScalarNode>(Field.getKey()); + if (!Key) { + YS.printError(Field.getKey(), "descriptor key must be a scalar"); + return false; + } + + Value = dyn_cast<yaml::ScalarNode>(Field.getValue()); + if (!Value) { + YS.printError(Field.getValue(), "descriptor value must be a scalar"); + return false; + } + + KeyValue = Key->getValue(KeyStorage); + if (KeyValue.equals("source")) { + std::string Error; + + Source = Value->getValue(ValueStorage); + if (!Regex(Source).isValid(Error)) { + YS.printError(Field.getKey(), "invalid regex: " + Error); + return false; + } + } else if (KeyValue.equals("target")) { + Target = Value->getValue(ValueStorage); + } else if (KeyValue.equals("transform")) { + Transform = Value->getValue(ValueStorage); + } else if (KeyValue.equals("naked")) { + std::string Undecorated; + + Undecorated = Value->getValue(ValueStorage); + Naked = StringRef(Undecorated).lower() == "true" || Undecorated == "1"; + } else { + YS.printError(Field.getKey(), "unknown key for function"); + return false; + } + } + + if (Transform.empty() == Target.empty()) { + YS.printError(Descriptor, + "exactly one of transform or target must be specified"); + return false; + } + + // TODO see if there is a more elegant solution to selecting the rewrite + // descriptor type + if (!Target.empty()) + DL->push_back(new ExplicitRewriteFunctionDescriptor(Source, Target, Naked)); + else + DL->push_back(new PatternRewriteFunctionDescriptor(Source, Transform)); + + return true; +} + +bool RewriteMapParser:: +parseRewriteGlobalVariableDescriptor(yaml::Stream &YS, yaml::ScalarNode *K, + yaml::MappingNode *Descriptor, + RewriteDescriptorList *DL) { + std::string Source; + std::string Target; + std::string Transform; + + for (auto &Field : *Descriptor) { + yaml::ScalarNode *Key; + yaml::ScalarNode *Value; + SmallString<32> KeyStorage; + SmallString<32> ValueStorage; + StringRef KeyValue; + + Key = dyn_cast<yaml::ScalarNode>(Field.getKey()); + if (!Key) { + YS.printError(Field.getKey(), "descriptor Key must be a scalar"); + return false; + } + + Value = dyn_cast<yaml::ScalarNode>(Field.getValue()); + if (!Value) { + YS.printError(Field.getValue(), "descriptor value must be a scalar"); + return false; + } + + KeyValue = Key->getValue(KeyStorage); + if (KeyValue.equals("source")) { + std::string Error; + + Source = Value->getValue(ValueStorage); + if (!Regex(Source).isValid(Error)) { + YS.printError(Field.getKey(), "invalid regex: " + Error); + return false; + } + } else if (KeyValue.equals("target")) { + Target = Value->getValue(ValueStorage); + } else if (KeyValue.equals("transform")) { + Transform = Value->getValue(ValueStorage); + } else { + YS.printError(Field.getKey(), "unknown Key for Global Variable"); + return false; + } + } + + if (Transform.empty() == Target.empty()) { + YS.printError(Descriptor, + "exactly one of transform or target must be specified"); + return false; + } + + if (!Target.empty()) + DL->push_back(new ExplicitRewriteGlobalVariableDescriptor(Source, Target, + /*Naked*/false)); + else + DL->push_back(new PatternRewriteGlobalVariableDescriptor(Source, + Transform)); + + return true; +} + +bool RewriteMapParser:: +parseRewriteGlobalAliasDescriptor(yaml::Stream &YS, yaml::ScalarNode *K, + yaml::MappingNode *Descriptor, + RewriteDescriptorList *DL) { + std::string Source; + std::string Target; + std::string Transform; + + for (auto &Field : *Descriptor) { + yaml::ScalarNode *Key; + yaml::ScalarNode *Value; + SmallString<32> KeyStorage; + SmallString<32> ValueStorage; + StringRef KeyValue; + + Key = dyn_cast<yaml::ScalarNode>(Field.getKey()); + if (!Key) { + YS.printError(Field.getKey(), "descriptor key must be a scalar"); + return false; + } + + Value = dyn_cast<yaml::ScalarNode>(Field.getValue()); + if (!Value) { + YS.printError(Field.getValue(), "descriptor value must be a scalar"); + return false; + } + + KeyValue = Key->getValue(KeyStorage); + if (KeyValue.equals("source")) { + std::string Error; + + Source = Value->getValue(ValueStorage); + if (!Regex(Source).isValid(Error)) { + YS.printError(Field.getKey(), "invalid regex: " + Error); + return false; + } + } else if (KeyValue.equals("target")) { + Target = Value->getValue(ValueStorage); + } else if (KeyValue.equals("transform")) { + Transform = Value->getValue(ValueStorage); + } else { + YS.printError(Field.getKey(), "unknown key for Global Alias"); + return false; + } + } + + if (Transform.empty() == Target.empty()) { + YS.printError(Descriptor, + "exactly one of transform or target must be specified"); + return false; + } + + if (!Target.empty()) + DL->push_back(new ExplicitRewriteNamedAliasDescriptor(Source, Target, + /*Naked*/false)); + else + DL->push_back(new PatternRewriteNamedAliasDescriptor(Source, Transform)); + + return true; +} +} +} + +namespace { +class RewriteSymbols : public ModulePass { +public: + static char ID; // Pass identification, replacement for typeid + + RewriteSymbols(); + RewriteSymbols(SymbolRewriter::RewriteDescriptorList &DL); + + bool runOnModule(Module &M) override; + +private: + void loadAndParseMapFiles(); + + SymbolRewriter::RewriteDescriptorList Descriptors; +}; + +char RewriteSymbols::ID = 0; + +RewriteSymbols::RewriteSymbols() : ModulePass(ID) { + initializeRewriteSymbolsPass(*PassRegistry::getPassRegistry()); + loadAndParseMapFiles(); +} + +RewriteSymbols::RewriteSymbols(SymbolRewriter::RewriteDescriptorList &DL) + : ModulePass(ID) { + Descriptors.splice(Descriptors.begin(), DL); +} + +bool RewriteSymbols::runOnModule(Module &M) { + bool Changed; + + Changed = false; + for (auto &Descriptor : Descriptors) + Changed |= Descriptor.performOnModule(M); + + return Changed; +} + +void RewriteSymbols::loadAndParseMapFiles() { + const std::vector<std::string> MapFiles(RewriteMapFiles); + SymbolRewriter::RewriteMapParser parser; + + for (const auto &MapFile : MapFiles) + parser.parse(MapFile, &Descriptors); +} +} + +INITIALIZE_PASS(RewriteSymbols, "rewrite-symbols", "Rewrite Symbols", false, + false) + +ModulePass *llvm::createRewriteSymbolsPass() { return new RewriteSymbols(); } + +ModulePass * +llvm::createRewriteSymbolsPass(SymbolRewriter::RewriteDescriptorList &DL) { + return new RewriteSymbols(DL); +} diff --git a/lib/Transforms/Utils/ValueMapper.cpp b/lib/Transforms/Utils/ValueMapper.cpp index 0f20e6df6c96..477fba42412e 100644 --- a/lib/Transforms/Utils/ValueMapper.cpp +++ b/lib/Transforms/Utils/ValueMapper.cpp @@ -40,7 +40,7 @@ Value *llvm::MapValue(const Value *V, ValueToValueMapTy &VM, RemapFlags Flags, // Global values do not need to be seeded into the VM if they // are using the identity mapping. - if (isa<GlobalValue>(V) || isa<MDString>(V)) + if (isa<GlobalValue>(V)) return VM[V] = const_cast<Value*>(V); if (const InlineAsm *IA = dyn_cast<InlineAsm>(V)) { @@ -56,57 +56,24 @@ Value *llvm::MapValue(const Value *V, ValueToValueMapTy &VM, RemapFlags Flags, return VM[V] = const_cast<Value*>(V); } - - if (const MDNode *MD = dyn_cast<MDNode>(V)) { + if (const auto *MDV = dyn_cast<MetadataAsValue>(V)) { + const Metadata *MD = MDV->getMetadata(); // If this is a module-level metadata and we know that nothing at the module // level is changing, then use an identity mapping. - if (!MD->isFunctionLocal() && (Flags & RF_NoModuleLevelChanges)) - return VM[V] = const_cast<Value*>(V); - - // Create a dummy node in case we have a metadata cycle. - MDNode *Dummy = MDNode::getTemporary(V->getContext(), None); - VM[V] = Dummy; - - // Check all operands to see if any need to be remapped. - for (unsigned i = 0, e = MD->getNumOperands(); i != e; ++i) { - Value *OP = MD->getOperand(i); - if (!OP) continue; - Value *Mapped_OP = MapValue(OP, VM, Flags, TypeMapper, Materializer); - // Use identity map if Mapped_Op is null and we can ignore missing - // entries. - if (Mapped_OP == OP || - (Mapped_OP == nullptr && (Flags & RF_IgnoreMissingEntries))) - continue; - - // Ok, at least one operand needs remapping. - SmallVector<Value*, 4> Elts; - Elts.reserve(MD->getNumOperands()); - for (i = 0; i != e; ++i) { - Value *Op = MD->getOperand(i); - if (!Op) - Elts.push_back(nullptr); - else { - Value *Mapped_Op = MapValue(Op, VM, Flags, TypeMapper, Materializer); - // Use identity map if Mapped_Op is null and we can ignore missing - // entries. - if (Mapped_Op == nullptr && (Flags & RF_IgnoreMissingEntries)) - Mapped_Op = Op; - Elts.push_back(Mapped_Op); - } - } - MDNode *NewMD = MDNode::get(V->getContext(), Elts); - Dummy->replaceAllUsesWith(NewMD); - VM[V] = NewMD; - MDNode::deleteTemporary(Dummy); - return NewMD; - } + if (!isa<LocalAsMetadata>(MD) && (Flags & RF_NoModuleLevelChanges)) + return VM[V] = const_cast<Value *>(V); - VM[V] = const_cast<Value*>(V); - MDNode::deleteTemporary(Dummy); + auto *MappedMD = MapMetadata(MD, VM, Flags, TypeMapper, Materializer); + if (MD == MappedMD || (!MappedMD && (Flags & RF_IgnoreMissingEntries))) + return VM[V] = const_cast<Value *>(V); - // No operands needed remapping. Use an identity mapping. - return const_cast<Value*>(V); + // FIXME: This assert crashes during bootstrap, but I think it should be + // correct. For now, just match behaviour from before the metadata/value + // split. + // + // assert(MappedMD && "Referenced metadata value not in value map"); + return VM[V] = MetadataAsValue::get(V->getContext(), MappedMD); } // Okay, this either must be a constant (which may or may not be mappable) or @@ -177,6 +144,229 @@ Value *llvm::MapValue(const Value *V, ValueToValueMapTy &VM, RemapFlags Flags, return VM[V] = ConstantPointerNull::get(cast<PointerType>(NewTy)); } +static Metadata *mapToMetadata(ValueToValueMapTy &VM, const Metadata *Key, + Metadata *Val) { + VM.MD()[Key].reset(Val); + return Val; +} + +static Metadata *mapToSelf(ValueToValueMapTy &VM, const Metadata *MD) { + return mapToMetadata(VM, MD, const_cast<Metadata *>(MD)); +} + +static Metadata *MapMetadataImpl(const Metadata *MD, ValueToValueMapTy &VM, + RemapFlags Flags, + ValueMapTypeRemapper *TypeMapper, + ValueMaterializer *Materializer); + +static Metadata *mapMetadataOp(Metadata *Op, ValueToValueMapTy &VM, + RemapFlags Flags, + ValueMapTypeRemapper *TypeMapper, + ValueMaterializer *Materializer) { + if (!Op) + return nullptr; + if (Metadata *MappedOp = + MapMetadataImpl(Op, VM, Flags, TypeMapper, Materializer)) + return MappedOp; + // Use identity map if MappedOp is null and we can ignore missing entries. + if (Flags & RF_IgnoreMissingEntries) + return Op; + + // FIXME: This assert crashes during bootstrap, but I think it should be + // correct. For now, just match behaviour from before the metadata/value + // split. + // + // llvm_unreachable("Referenced metadata not in value map!"); + return nullptr; +} + +static Metadata *cloneMDTuple(const MDTuple *Node, ValueToValueMapTy &VM, + RemapFlags Flags, + ValueMapTypeRemapper *TypeMapper, + ValueMaterializer *Materializer, + bool IsDistinct) { + // Distinct MDTuples have their own code path. + assert(!IsDistinct && "Unexpected distinct tuple"); + (void)IsDistinct; + + SmallVector<Metadata *, 4> Elts; + Elts.reserve(Node->getNumOperands()); + for (unsigned I = 0, E = Node->getNumOperands(); I != E; ++I) + Elts.push_back(mapMetadataOp(Node->getOperand(I), VM, Flags, TypeMapper, + Materializer)); + + return MDTuple::get(Node->getContext(), Elts); +} + +static Metadata *cloneMDLocation(const MDLocation *Node, ValueToValueMapTy &VM, + RemapFlags Flags, + ValueMapTypeRemapper *TypeMapper, + ValueMaterializer *Materializer, + bool IsDistinct) { + return (IsDistinct ? MDLocation::getDistinct : MDLocation::get)( + Node->getContext(), Node->getLine(), Node->getColumn(), + mapMetadataOp(Node->getScope(), VM, Flags, TypeMapper, Materializer), + mapMetadataOp(Node->getInlinedAt(), VM, Flags, TypeMapper, Materializer)); +} + +static Metadata *cloneMDNode(const UniquableMDNode *Node, ValueToValueMapTy &VM, + RemapFlags Flags, ValueMapTypeRemapper *TypeMapper, + ValueMaterializer *Materializer, bool IsDistinct) { + switch (Node->getMetadataID()) { + default: + llvm_unreachable("Invalid UniquableMDNode subclass"); +#define HANDLE_UNIQUABLE_LEAF(CLASS) \ + case Metadata::CLASS##Kind: \ + return clone##CLASS(cast<CLASS>(Node), VM, Flags, TypeMapper, \ + Materializer, IsDistinct); +#include "llvm/IR/Metadata.def" + } +} + +/// \brief Map a distinct MDNode. +/// +/// Distinct nodes are not uniqued, so they must always recreated. +static Metadata *mapDistinctNode(const UniquableMDNode *Node, + ValueToValueMapTy &VM, RemapFlags Flags, + ValueMapTypeRemapper *TypeMapper, + ValueMaterializer *Materializer) { + assert(Node->isDistinct() && "Expected distinct node"); + + // Optimization for MDTuples. + if (isa<MDTuple>(Node)) { + // Create the node first so it's available for cyclical references. + SmallVector<Metadata *, 4> EmptyOps(Node->getNumOperands()); + MDTuple *NewMD = MDTuple::getDistinct(Node->getContext(), EmptyOps); + mapToMetadata(VM, Node, NewMD); + + // Fix the operands. + for (unsigned I = 0, E = Node->getNumOperands(); I != E; ++I) + NewMD->replaceOperandWith(I, mapMetadataOp(Node->getOperand(I), VM, Flags, + TypeMapper, Materializer)); + + return NewMD; + } + + // In general we need a dummy node, since whether the operands are null can + // affect the size of the node. + std::unique_ptr<MDNodeFwdDecl> Dummy( + MDNode::getTemporary(Node->getContext(), None)); + mapToMetadata(VM, Node, Dummy.get()); + Metadata *NewMD = cloneMDNode(Node, VM, Flags, TypeMapper, Materializer, + /* IsDistinct */ true); + Dummy->replaceAllUsesWith(NewMD); + return mapToMetadata(VM, Node, NewMD); +} + +/// \brief Check whether a uniqued node needs to be remapped. +/// +/// Check whether a uniqued node needs to be remapped (due to any operands +/// changing). +static bool shouldRemapUniquedNode(const UniquableMDNode *Node, + ValueToValueMapTy &VM, RemapFlags Flags, + ValueMapTypeRemapper *TypeMapper, + ValueMaterializer *Materializer) { + // Check all operands to see if any need to be remapped. + for (unsigned I = 0, E = Node->getNumOperands(); I != E; ++I) { + Metadata *Op = Node->getOperand(I); + if (Op != mapMetadataOp(Op, VM, Flags, TypeMapper, Materializer)) + return true; + } + return false; +} + +/// \brief Map a uniqued MDNode. +/// +/// Uniqued nodes may not need to be recreated (they may map to themselves). +static Metadata *mapUniquedNode(const UniquableMDNode *Node, + ValueToValueMapTy &VM, RemapFlags Flags, + ValueMapTypeRemapper *TypeMapper, + ValueMaterializer *Materializer) { + assert(!Node->isDistinct() && "Expected uniqued node"); + + // Create a dummy node in case we have a metadata cycle. + MDNodeFwdDecl *Dummy = MDNode::getTemporary(Node->getContext(), None); + mapToMetadata(VM, Node, Dummy); + + // Check all operands to see if any need to be remapped. + if (!shouldRemapUniquedNode(Node, VM, Flags, TypeMapper, Materializer)) { + // Use an identity mapping. + mapToSelf(VM, Node); + MDNode::deleteTemporary(Dummy); + return const_cast<Metadata *>(static_cast<const Metadata *>(Node)); + } + + // At least one operand needs remapping. + Metadata *NewMD = cloneMDNode(Node, VM, Flags, TypeMapper, Materializer, + /* IsDistinct */ false); + Dummy->replaceAllUsesWith(NewMD); + MDNode::deleteTemporary(Dummy); + return mapToMetadata(VM, Node, NewMD); +} + +static Metadata *MapMetadataImpl(const Metadata *MD, ValueToValueMapTy &VM, + RemapFlags Flags, + ValueMapTypeRemapper *TypeMapper, + ValueMaterializer *Materializer) { + // If the value already exists in the map, use it. + if (Metadata *NewMD = VM.MD().lookup(MD).get()) + return NewMD; + + if (isa<MDString>(MD)) + return mapToSelf(VM, MD); + + if (isa<ConstantAsMetadata>(MD)) + if ((Flags & RF_NoModuleLevelChanges)) + return mapToSelf(VM, MD); + + if (const auto *VMD = dyn_cast<ValueAsMetadata>(MD)) { + Value *MappedV = + MapValue(VMD->getValue(), VM, Flags, TypeMapper, Materializer); + if (VMD->getValue() == MappedV || + (!MappedV && (Flags & RF_IgnoreMissingEntries))) + return mapToSelf(VM, MD); + + // FIXME: This assert crashes during bootstrap, but I think it should be + // correct. For now, just match behaviour from before the metadata/value + // split. + // + // assert(MappedV && "Referenced metadata not in value map!"); + if (MappedV) + return mapToMetadata(VM, MD, ValueAsMetadata::get(MappedV)); + return nullptr; + } + + const UniquableMDNode *Node = cast<UniquableMDNode>(MD); + assert(Node->isResolved() && "Unexpected unresolved node"); + + // If this is a module-level metadata and we know that nothing at the + // module level is changing, then use an identity mapping. + if (Flags & RF_NoModuleLevelChanges) + return mapToSelf(VM, MD); + + if (Node->isDistinct()) + return mapDistinctNode(Node, VM, Flags, TypeMapper, Materializer); + + return mapUniquedNode(Node, VM, Flags, TypeMapper, Materializer); +} + +Metadata *llvm::MapMetadata(const Metadata *MD, ValueToValueMapTy &VM, + RemapFlags Flags, ValueMapTypeRemapper *TypeMapper, + ValueMaterializer *Materializer) { + Metadata *NewMD = MapMetadataImpl(MD, VM, Flags, TypeMapper, Materializer); + if (NewMD && NewMD != MD) + if (auto *N = dyn_cast<UniquableMDNode>(NewMD)) + N->resolveCycles(); + return NewMD; +} + +MDNode *llvm::MapMetadata(const MDNode *MD, ValueToValueMapTy &VM, + RemapFlags Flags, ValueMapTypeRemapper *TypeMapper, + ValueMaterializer *Materializer) { + return cast<MDNode>(MapMetadata(static_cast<const Metadata *>(MD), VM, Flags, + TypeMapper, Materializer)); +} + /// RemapInstruction - Convert the instruction operands from referencing the /// current values into those specified by VMap. /// @@ -210,10 +400,12 @@ void llvm::RemapInstruction(Instruction *I, ValueToValueMapTy &VMap, // Remap attached metadata. SmallVector<std::pair<unsigned, MDNode *>, 4> MDs; I->getAllMetadata(MDs); - for (SmallVectorImpl<std::pair<unsigned, MDNode *> >::iterator - MI = MDs.begin(), ME = MDs.end(); MI != ME; ++MI) { + for (SmallVectorImpl<std::pair<unsigned, MDNode *>>::iterator + MI = MDs.begin(), + ME = MDs.end(); + MI != ME; ++MI) { MDNode *Old = MI->second; - MDNode *New = MapValue(Old, VMap, Flags, TypeMapper, Materializer); + MDNode *New = MapMetadata(Old, VMap, Flags, TypeMapper, Materializer); if (New != Old) I->setMetadata(MI->first, New); } |