diff options
Diffstat (limited to 'llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp')
| -rw-r--r-- | llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp | 663 |
1 files changed, 448 insertions, 215 deletions
diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp index e02d02a05752..f4306bb43dfd 100644 --- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -14,28 +14,23 @@ #include "llvm/Transforms/Utils/SimplifyLibCalls.h" #include "llvm/ADT/APSInt.h" #include "llvm/ADT/SmallString.h" -#include "llvm/ADT/StringMap.h" #include "llvm/ADT/Triple.h" -#include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/Loads.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" -#include "llvm/Analysis/ProfileSummaryInfo.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" -#include "llvm/Analysis/CaptureTracking.h" -#include "llvm/Analysis/Loads.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" -#include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/KnownBits.h" #include "llvm/Support/MathExtras.h" #include "llvm/Transforms/Utils/BuildLibCalls.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/SizeOpts.h" using namespace llvm; @@ -206,6 +201,11 @@ static Value *copyFlags(const CallInst &Old, Value *New) { return New; } +// Helper to avoid truncating the length if size_t is 32-bits. +static StringRef substr(StringRef Str, uint64_t Len) { + return Len >= Str.size() ? Str : Str.substr(0, Len); +} + //===----------------------------------------------------------------------===// // String and Memory Library Call Optimizations //===----------------------------------------------------------------------===// @@ -242,7 +242,7 @@ Value *LibCallSimplifier::emitStrLenMemCpy(Value *Src, Value *Dst, uint64_t Len, // Now that we have the destination's length, we must index into the // destination's pointer to get the actual memcpy destination (end of // the string .. we're concatenating). - Value *CpyDst = B.CreateGEP(B.getInt8Ty(), Dst, DstLen, "endptr"); + Value *CpyDst = B.CreateInBoundsGEP(B.getInt8Ty(), Dst, DstLen, "endptr"); // We have enough information to now generate the memcpy call to do the // concatenation for us. Make a memcpy to copy the nul byte with align = 1. @@ -326,7 +326,7 @@ Value *LibCallSimplifier::optimizeStrChr(CallInst *CI, IRBuilderBase &B) { if (!getConstantStringInfo(SrcStr, Str)) { if (CharC->isZero()) // strchr(p, 0) -> p + strlen(p) if (Value *StrLen = emitStrLen(SrcStr, B, DL, TLI)) - return B.CreateGEP(B.getInt8Ty(), SrcStr, StrLen, "strchr"); + return B.CreateInBoundsGEP(B.getInt8Ty(), SrcStr, StrLen, "strchr"); return nullptr; } @@ -339,35 +339,29 @@ Value *LibCallSimplifier::optimizeStrChr(CallInst *CI, IRBuilderBase &B) { return Constant::getNullValue(CI->getType()); // strchr(s+n,c) -> gep(s+n+i,c) - return B.CreateGEP(B.getInt8Ty(), SrcStr, B.getInt64(I), "strchr"); + return B.CreateInBoundsGEP(B.getInt8Ty(), SrcStr, B.getInt64(I), "strchr"); } Value *LibCallSimplifier::optimizeStrRChr(CallInst *CI, IRBuilderBase &B) { Value *SrcStr = CI->getArgOperand(0); - ConstantInt *CharC = dyn_cast<ConstantInt>(CI->getArgOperand(1)); + Value *CharVal = CI->getArgOperand(1); + ConstantInt *CharC = dyn_cast<ConstantInt>(CharVal); annotateNonNullNoUndefBasedOnAccess(CI, 0); - // Cannot fold anything if we're not looking for a constant. - if (!CharC) - return nullptr; - StringRef Str; if (!getConstantStringInfo(SrcStr, Str)) { // strrchr(s, 0) -> strchr(s, 0) - if (CharC->isZero()) + if (CharC && CharC->isZero()) return copyFlags(*CI, emitStrChr(SrcStr, '\0', B, TLI)); return nullptr; } - // Compute the offset. - size_t I = (0xFF & CharC->getSExtValue()) == 0 - ? Str.size() - : Str.rfind(CharC->getSExtValue()); - if (I == StringRef::npos) // Didn't find the char. Return null. - return Constant::getNullValue(CI->getType()); - - // strrchr(s+n,c) -> gep(s+n+i,c) - return B.CreateGEP(B.getInt8Ty(), SrcStr, B.getInt64(I), "strrchr"); + // Try to expand strrchr to the memrchr nonstandard extension if it's + // available, or simply fail otherwise. + uint64_t NBytes = Str.size() + 1; // Include the terminating nul. + Type *IntPtrType = DL.getIntPtrType(CI->getContext()); + Value *Size = ConstantInt::get(IntPtrType, NBytes); + return copyFlags(*CI, emitMemRChr(SrcStr, CharVal, Size, B, DL, TLI)); } Value *LibCallSimplifier::optimizeStrCmp(CallInst *CI, IRBuilderBase &B) { @@ -428,6 +422,12 @@ Value *LibCallSimplifier::optimizeStrCmp(CallInst *CI, IRBuilderBase &B) { return nullptr; } +// Optimize a memcmp or, when StrNCmp is true, strncmp call CI with constant +// arrays LHS and RHS and nonconstant Size. +static Value *optimizeMemCmpVarSize(CallInst *CI, Value *LHS, Value *RHS, + Value *Size, bool StrNCmp, + IRBuilderBase &B, const DataLayout &DL); + Value *LibCallSimplifier::optimizeStrNCmp(CallInst *CI, IRBuilderBase &B) { Value *Str1P = CI->getArgOperand(0); Value *Str2P = CI->getArgOperand(1); @@ -442,7 +442,7 @@ Value *LibCallSimplifier::optimizeStrNCmp(CallInst *CI, IRBuilderBase &B) { if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(Size)) Length = LengthArg->getZExtValue(); else - return nullptr; + return optimizeMemCmpVarSize(CI, Str1P, Str2P, Size, true, B, DL); if (Length == 0) // strncmp(x,y,0) -> 0 return ConstantInt::get(CI->getType(), 0); @@ -456,8 +456,9 @@ Value *LibCallSimplifier::optimizeStrNCmp(CallInst *CI, IRBuilderBase &B) { // strncmp(x, y) -> cnst (if both x and y are constant strings) if (HasStr1 && HasStr2) { - StringRef SubStr1 = Str1.substr(0, Length); - StringRef SubStr2 = Str2.substr(0, Length); + // Avoid truncating the 64-bit Length to 32 bits in ILP32. + StringRef SubStr1 = substr(Str1, Length); + StringRef SubStr2 = substr(Str2, Length); return ConstantInt::get(CI->getType(), SubStr1.compare(SubStr2)); } @@ -557,8 +558,8 @@ Value *LibCallSimplifier::optimizeStpCpy(CallInst *CI, IRBuilderBase &B) { Type *PT = Callee->getFunctionType()->getParamType(0); Value *LenV = ConstantInt::get(DL.getIntPtrType(PT), Len); - Value *DstEnd = B.CreateGEP(B.getInt8Ty(), Dst, - ConstantInt::get(DL.getIntPtrType(PT), Len - 1)); + Value *DstEnd = B.CreateInBoundsGEP( + B.getInt8Ty(), Dst, ConstantInt::get(DL.getIntPtrType(PT), Len - 1)); // We have enough information to now generate the memcpy call to do the // copy for us. Make a memcpy to copy the nul byte with align = 1. @@ -634,12 +635,51 @@ Value *LibCallSimplifier::optimizeStrNCpy(CallInst *CI, IRBuilderBase &B) { } Value *LibCallSimplifier::optimizeStringLength(CallInst *CI, IRBuilderBase &B, - unsigned CharSize) { + unsigned CharSize, + Value *Bound) { Value *Src = CI->getArgOperand(0); + Type *CharTy = B.getIntNTy(CharSize); + + if (isOnlyUsedInZeroEqualityComparison(CI) && + (!Bound || isKnownNonZero(Bound, DL))) { + // Fold strlen: + // strlen(x) != 0 --> *x != 0 + // strlen(x) == 0 --> *x == 0 + // and likewise strnlen with constant N > 0: + // strnlen(x, N) != 0 --> *x != 0 + // strnlen(x, N) == 0 --> *x == 0 + return B.CreateZExt(B.CreateLoad(CharTy, Src, "char0"), + CI->getType()); + } - // Constant folding: strlen("xyz") -> 3 - if (uint64_t Len = GetStringLength(Src, CharSize)) - return ConstantInt::get(CI->getType(), Len - 1); + if (Bound) { + if (ConstantInt *BoundCst = dyn_cast<ConstantInt>(Bound)) { + if (BoundCst->isZero()) + // Fold strnlen(s, 0) -> 0 for any s, constant or otherwise. + return ConstantInt::get(CI->getType(), 0); + + if (BoundCst->isOne()) { + // Fold strnlen(s, 1) -> *s ? 1 : 0 for any s. + Value *CharVal = B.CreateLoad(CharTy, Src, "strnlen.char0"); + Value *ZeroChar = ConstantInt::get(CharTy, 0); + Value *Cmp = B.CreateICmpNE(CharVal, ZeroChar, "strnlen.char0cmp"); + return B.CreateZExt(Cmp, CI->getType()); + } + } + } + + if (uint64_t Len = GetStringLength(Src, CharSize)) { + Value *LenC = ConstantInt::get(CI->getType(), Len - 1); + // Fold strlen("xyz") -> 3 and strnlen("xyz", 2) -> 2 + // and strnlen("xyz", Bound) -> min(3, Bound) for nonconstant Bound. + if (Bound) + return B.CreateBinaryIntrinsic(Intrinsic::umin, LenC, Bound); + return LenC; + } + + if (Bound) + // Punt for strnlen for now. + return nullptr; // If s is a constant pointer pointing to a string literal, we can fold // strlen(s + x) to strlen(s) - x, when x is known to be in the range @@ -650,6 +690,7 @@ Value *LibCallSimplifier::optimizeStringLength(CallInst *CI, IRBuilderBase &B, // very useful because calling strlen for a pointer of other types is // very uncommon. if (GEPOperator *GEP = dyn_cast<GEPOperator>(Src)) { + // TODO: Handle subobjects. if (!isGEPBasedOnPointerToString(GEP, CharSize)) return nullptr; @@ -674,22 +715,15 @@ Value *LibCallSimplifier::optimizeStringLength(CallInst *CI, IRBuilderBase &B, Value *Offset = GEP->getOperand(2); KnownBits Known = computeKnownBits(Offset, DL, 0, nullptr, CI, nullptr); - Known.Zero.flipAllBits(); uint64_t ArrSize = cast<ArrayType>(GEP->getSourceElementType())->getNumElements(); - // KnownZero's bits are flipped, so zeros in KnownZero now represent - // bits known to be zeros in Offset, and ones in KnowZero represent - // bits unknown in Offset. Therefore, Offset is known to be in range - // [0, NullTermIdx] when the flipped KnownZero is non-negative and - // unsigned-less-than NullTermIdx. - // // If Offset is not provably in the range [0, NullTermIdx], we can still // optimize if we can prove that the program has undefined behavior when // Offset is outside that range. That is the case when GEP->getOperand(0) // is a pointer to an object whose memory extent is NullTermIdx+1. - if ((Known.Zero.isNonNegative() && Known.Zero.ule(NullTermIdx)) || - (GEP->isInBounds() && isa<GlobalVariable>(GEP->getOperand(0)) && + if ((Known.isNonNegative() && Known.getMaxValue().ule(NullTermIdx)) || + (isa<GlobalVariable>(GEP->getOperand(0)) && NullTermIdx == ArrSize - 1)) { Offset = B.CreateSExtOrTrunc(Offset, CI->getType()); return B.CreateSub(ConstantInt::get(CI->getType(), NullTermIdx), @@ -713,12 +747,6 @@ Value *LibCallSimplifier::optimizeStringLength(CallInst *CI, IRBuilderBase &B, } } - // strlen(x) != 0 --> *x != 0 - // strlen(x) == 0 --> *x == 0 - if (isOnlyUsedInZeroEqualityComparison(CI)) - return B.CreateZExt(B.CreateLoad(B.getIntNTy(CharSize), Src, "strlenfirst"), - CI->getType()); - return nullptr; } @@ -729,6 +757,16 @@ Value *LibCallSimplifier::optimizeStrLen(CallInst *CI, IRBuilderBase &B) { return nullptr; } +Value *LibCallSimplifier::optimizeStrNLen(CallInst *CI, IRBuilderBase &B) { + Value *Bound = CI->getArgOperand(1); + if (Value *V = optimizeStringLength(CI, B, 8, Bound)) + return V; + + if (isKnownNonZero(Bound, DL)) + annotateNonNullNoUndefBasedOnAccess(CI, 0); + return nullptr; +} + Value *LibCallSimplifier::optimizeWcslen(CallInst *CI, IRBuilderBase &B) { Module &M = *CI->getModule(); unsigned WCharSize = TLI->getWCharSize(M) * 8; @@ -755,8 +793,8 @@ Value *LibCallSimplifier::optimizeStrPBrk(CallInst *CI, IRBuilderBase &B) { if (I == StringRef::npos) // No match. return Constant::getNullValue(CI->getType()); - return B.CreateGEP(B.getInt8Ty(), CI->getArgOperand(0), B.getInt64(I), - "strpbrk"); + return B.CreateInBoundsGEP(B.getInt8Ty(), CI->getArgOperand(0), + B.getInt64(I), "strpbrk"); } // strpbrk(s, "a") -> strchr(s, 'a') @@ -880,35 +918,190 @@ Value *LibCallSimplifier::optimizeStrStr(CallInst *CI, IRBuilderBase &B) { } Value *LibCallSimplifier::optimizeMemRChr(CallInst *CI, IRBuilderBase &B) { - if (isKnownNonZero(CI->getOperand(2), DL)) - annotateNonNullNoUndefBasedOnAccess(CI, 0); - return nullptr; + Value *SrcStr = CI->getArgOperand(0); + Value *Size = CI->getArgOperand(2); + annotateNonNullAndDereferenceable(CI, 0, Size, DL); + Value *CharVal = CI->getArgOperand(1); + ConstantInt *LenC = dyn_cast<ConstantInt>(Size); + Value *NullPtr = Constant::getNullValue(CI->getType()); + + if (LenC) { + if (LenC->isZero()) + // Fold memrchr(x, y, 0) --> null. + return NullPtr; + + if (LenC->isOne()) { + // Fold memrchr(x, y, 1) --> *x == y ? x : null for any x and y, + // constant or otherwise. + Value *Val = B.CreateLoad(B.getInt8Ty(), SrcStr, "memrchr.char0"); + // Slice off the character's high end bits. + CharVal = B.CreateTrunc(CharVal, B.getInt8Ty()); + Value *Cmp = B.CreateICmpEQ(Val, CharVal, "memrchr.char0cmp"); + return B.CreateSelect(Cmp, SrcStr, NullPtr, "memrchr.sel"); + } + } + + StringRef Str; + if (!getConstantStringInfo(SrcStr, Str, 0, /*TrimAtNul=*/false)) + return nullptr; + + if (Str.size() == 0) + // If the array is empty fold memrchr(A, C, N) to null for any value + // of C and N on the basis that the only valid value of N is zero + // (otherwise the call is undefined). + return NullPtr; + + uint64_t EndOff = UINT64_MAX; + if (LenC) { + EndOff = LenC->getZExtValue(); + if (Str.size() < EndOff) + // Punt out-of-bounds accesses to sanitizers and/or libc. + return nullptr; + } + + if (ConstantInt *CharC = dyn_cast<ConstantInt>(CharVal)) { + // Fold memrchr(S, C, N) for a constant C. + size_t Pos = Str.rfind(CharC->getZExtValue(), EndOff); + if (Pos == StringRef::npos) + // When the character is not in the source array fold the result + // to null regardless of Size. + return NullPtr; + + if (LenC) + // Fold memrchr(s, c, N) --> s + Pos for constant N > Pos. + return B.CreateInBoundsGEP(B.getInt8Ty(), SrcStr, B.getInt64(Pos)); + + if (Str.find(Str[Pos]) == Pos) { + // When there is just a single occurrence of C in S, i.e., the one + // in Str[Pos], fold + // memrchr(s, c, N) --> N <= Pos ? null : s + Pos + // for nonconstant N. + Value *Cmp = B.CreateICmpULE(Size, ConstantInt::get(Size->getType(), Pos), + "memrchr.cmp"); + Value *SrcPlus = B.CreateInBoundsGEP(B.getInt8Ty(), SrcStr, + B.getInt64(Pos), "memrchr.ptr_plus"); + return B.CreateSelect(Cmp, NullPtr, SrcPlus, "memrchr.sel"); + } + } + + // Truncate the string to search at most EndOff characters. + Str = Str.substr(0, EndOff); + if (Str.find_first_not_of(Str[0]) != StringRef::npos) + return nullptr; + + // If the source array consists of all equal characters, then for any + // C and N (whether in bounds or not), fold memrchr(S, C, N) to + // N != 0 && *S == C ? S + N - 1 : null + Type *SizeTy = Size->getType(); + Type *Int8Ty = B.getInt8Ty(); + Value *NNeZ = B.CreateICmpNE(Size, ConstantInt::get(SizeTy, 0)); + // Slice off the sought character's high end bits. + CharVal = B.CreateTrunc(CharVal, Int8Ty); + Value *CEqS0 = B.CreateICmpEQ(ConstantInt::get(Int8Ty, Str[0]), CharVal); + Value *And = B.CreateLogicalAnd(NNeZ, CEqS0); + Value *SizeM1 = B.CreateSub(Size, ConstantInt::get(SizeTy, 1)); + Value *SrcPlus = + B.CreateInBoundsGEP(Int8Ty, SrcStr, SizeM1, "memrchr.ptr_plus"); + return B.CreateSelect(And, SrcPlus, NullPtr, "memrchr.sel"); } Value *LibCallSimplifier::optimizeMemChr(CallInst *CI, IRBuilderBase &B) { Value *SrcStr = CI->getArgOperand(0); Value *Size = CI->getArgOperand(2); - annotateNonNullAndDereferenceable(CI, 0, Size, DL); - ConstantInt *CharC = dyn_cast<ConstantInt>(CI->getArgOperand(1)); + if (isKnownNonZero(Size, DL)) + annotateNonNullNoUndefBasedOnAccess(CI, 0); + + Value *CharVal = CI->getArgOperand(1); + ConstantInt *CharC = dyn_cast<ConstantInt>(CharVal); ConstantInt *LenC = dyn_cast<ConstantInt>(Size); + Value *NullPtr = Constant::getNullValue(CI->getType()); // memchr(x, y, 0) -> null if (LenC) { if (LenC->isZero()) - return Constant::getNullValue(CI->getType()); - } else { - // From now on we need at least constant length and string. - return nullptr; + return NullPtr; + + if (LenC->isOne()) { + // Fold memchr(x, y, 1) --> *x == y ? x : null for any x and y, + // constant or otherwise. + Value *Val = B.CreateLoad(B.getInt8Ty(), SrcStr, "memchr.char0"); + // Slice off the character's high end bits. + CharVal = B.CreateTrunc(CharVal, B.getInt8Ty()); + Value *Cmp = B.CreateICmpEQ(Val, CharVal, "memchr.char0cmp"); + return B.CreateSelect(Cmp, SrcStr, NullPtr, "memchr.sel"); + } } StringRef Str; if (!getConstantStringInfo(SrcStr, Str, 0, /*TrimAtNul=*/false)) return nullptr; - // Truncate the string to LenC. If Str is smaller than LenC we will still only - // scan the string, as reading past the end of it is undefined and we can just - // return null if we don't find the char. - Str = Str.substr(0, LenC->getZExtValue()); + if (CharC) { + size_t Pos = Str.find(CharC->getZExtValue()); + if (Pos == StringRef::npos) + // When the character is not in the source array fold the result + // to null regardless of Size. + return NullPtr; + + // Fold memchr(s, c, n) -> n <= Pos ? null : s + Pos + // When the constant Size is less than or equal to the character + // position also fold the result to null. + Value *Cmp = B.CreateICmpULE(Size, ConstantInt::get(Size->getType(), Pos), + "memchr.cmp"); + Value *SrcPlus = B.CreateInBoundsGEP(B.getInt8Ty(), SrcStr, B.getInt64(Pos), + "memchr.ptr"); + return B.CreateSelect(Cmp, NullPtr, SrcPlus); + } + + if (Str.size() == 0) + // If the array is empty fold memchr(A, C, N) to null for any value + // of C and N on the basis that the only valid value of N is zero + // (otherwise the call is undefined). + return NullPtr; + + if (LenC) + Str = substr(Str, LenC->getZExtValue()); + + size_t Pos = Str.find_first_not_of(Str[0]); + if (Pos == StringRef::npos + || Str.find_first_not_of(Str[Pos], Pos) == StringRef::npos) { + // If the source array consists of at most two consecutive sequences + // of the same characters, then for any C and N (whether in bounds or + // not), fold memchr(S, C, N) to + // N != 0 && *S == C ? S : null + // or for the two sequences to: + // N != 0 && *S == C ? S : (N > Pos && S[Pos] == C ? S + Pos : null) + // ^Sel2 ^Sel1 are denoted above. + // The latter makes it also possible to fold strchr() calls with strings + // of the same characters. + Type *SizeTy = Size->getType(); + Type *Int8Ty = B.getInt8Ty(); + + // Slice off the sought character's high end bits. + CharVal = B.CreateTrunc(CharVal, Int8Ty); + + Value *Sel1 = NullPtr; + if (Pos != StringRef::npos) { + // Handle two consecutive sequences of the same characters. + Value *PosVal = ConstantInt::get(SizeTy, Pos); + Value *StrPos = ConstantInt::get(Int8Ty, Str[Pos]); + Value *CEqSPos = B.CreateICmpEQ(CharVal, StrPos); + Value *NGtPos = B.CreateICmp(ICmpInst::ICMP_UGT, Size, PosVal); + Value *And = B.CreateAnd(CEqSPos, NGtPos); + Value *SrcPlus = B.CreateInBoundsGEP(B.getInt8Ty(), SrcStr, PosVal); + Sel1 = B.CreateSelect(And, SrcPlus, NullPtr, "memchr.sel1"); + } + + Value *Str0 = ConstantInt::get(Int8Ty, Str[0]); + Value *CEqS0 = B.CreateICmpEQ(Str0, CharVal); + Value *NNeZ = B.CreateICmpNE(Size, ConstantInt::get(SizeTy, 0)); + Value *And = B.CreateAnd(NNeZ, CEqS0); + return B.CreateSelect(And, SrcStr, Sel1, "memchr.sel2"); + } + + if (!LenC) + // From now on we need a constant length and constant array. + return nullptr; // If the char is variable but the input str and length are not we can turn // this memchr call into a simple bit field test. Of course this only works @@ -920,60 +1113,93 @@ Value *LibCallSimplifier::optimizeMemChr(CallInst *CI, IRBuilderBase &B) { // memchr("\r\n", C, 2) != nullptr -> (1 << C & ((1 << '\r') | (1 << '\n'))) // != 0 // after bounds check. - if (!CharC && !Str.empty() && isOnlyUsedInZeroEqualityComparison(CI)) { - unsigned char Max = - *std::max_element(reinterpret_cast<const unsigned char *>(Str.begin()), - reinterpret_cast<const unsigned char *>(Str.end())); + if (Str.empty() || !isOnlyUsedInZeroEqualityComparison(CI)) + return nullptr; - // Make sure the bit field we're about to create fits in a register on the - // target. - // FIXME: On a 64 bit architecture this prevents us from using the - // interesting range of alpha ascii chars. We could do better by emitting - // two bitfields or shifting the range by 64 if no lower chars are used. - if (!DL.fitsInLegalInteger(Max + 1)) - return nullptr; + unsigned char Max = + *std::max_element(reinterpret_cast<const unsigned char *>(Str.begin()), + reinterpret_cast<const unsigned char *>(Str.end())); - // For the bit field use a power-of-2 type with at least 8 bits to avoid - // creating unnecessary illegal types. - unsigned char Width = NextPowerOf2(std::max((unsigned char)7, Max)); + // Make sure the bit field we're about to create fits in a register on the + // target. + // FIXME: On a 64 bit architecture this prevents us from using the + // interesting range of alpha ascii chars. We could do better by emitting + // two bitfields or shifting the range by 64 if no lower chars are used. + if (!DL.fitsInLegalInteger(Max + 1)) + return nullptr; - // Now build the bit field. - APInt Bitfield(Width, 0); - for (char C : Str) - Bitfield.setBit((unsigned char)C); - Value *BitfieldC = B.getInt(Bitfield); + // For the bit field use a power-of-2 type with at least 8 bits to avoid + // creating unnecessary illegal types. + unsigned char Width = NextPowerOf2(std::max((unsigned char)7, Max)); - // Adjust width of "C" to the bitfield width, then mask off the high bits. - Value *C = B.CreateZExtOrTrunc(CI->getArgOperand(1), BitfieldC->getType()); - C = B.CreateAnd(C, B.getIntN(Width, 0xFF)); + // Now build the bit field. + APInt Bitfield(Width, 0); + for (char C : Str) + Bitfield.setBit((unsigned char)C); + Value *BitfieldC = B.getInt(Bitfield); - // First check that the bit field access is within bounds. - Value *Bounds = B.CreateICmp(ICmpInst::ICMP_ULT, C, B.getIntN(Width, Width), - "memchr.bounds"); + // Adjust width of "C" to the bitfield width, then mask off the high bits. + Value *C = B.CreateZExtOrTrunc(CharVal, BitfieldC->getType()); + C = B.CreateAnd(C, B.getIntN(Width, 0xFF)); - // Create code that checks if the given bit is set in the field. - Value *Shl = B.CreateShl(B.getIntN(Width, 1ULL), C); - Value *Bits = B.CreateIsNotNull(B.CreateAnd(Shl, BitfieldC), "memchr.bits"); + // First check that the bit field access is within bounds. + Value *Bounds = B.CreateICmp(ICmpInst::ICMP_ULT, C, B.getIntN(Width, Width), + "memchr.bounds"); - // Finally merge both checks and cast to pointer type. The inttoptr - // implicitly zexts the i1 to intptr type. - return B.CreateIntToPtr(B.CreateLogicalAnd(Bounds, Bits, "memchr"), - CI->getType()); - } + // Create code that checks if the given bit is set in the field. + Value *Shl = B.CreateShl(B.getIntN(Width, 1ULL), C); + Value *Bits = B.CreateIsNotNull(B.CreateAnd(Shl, BitfieldC), "memchr.bits"); - // Check if all arguments are constants. If so, we can constant fold. - if (!CharC) - return nullptr; + // Finally merge both checks and cast to pointer type. The inttoptr + // implicitly zexts the i1 to intptr type. + return B.CreateIntToPtr(B.CreateLogicalAnd(Bounds, Bits, "memchr"), + CI->getType()); +} - // Compute the offset. - size_t I = Str.find(CharC->getSExtValue() & 0xFF); - if (I == StringRef::npos) // Didn't find the char. memchr returns null. +// Optimize a memcmp or, when StrNCmp is true, strncmp call CI with constant +// arrays LHS and RHS and nonconstant Size. +static Value *optimizeMemCmpVarSize(CallInst *CI, Value *LHS, Value *RHS, + Value *Size, bool StrNCmp, + IRBuilderBase &B, const DataLayout &DL) { + if (LHS == RHS) // memcmp(s,s,x) -> 0 return Constant::getNullValue(CI->getType()); - // memchr(s+n,c,l) -> gep(s+n+i,c) - return B.CreateGEP(B.getInt8Ty(), SrcStr, B.getInt64(I), "memchr"); + StringRef LStr, RStr; + if (!getConstantStringInfo(LHS, LStr, 0, /*TrimAtNul=*/false) || + !getConstantStringInfo(RHS, RStr, 0, /*TrimAtNul=*/false)) + return nullptr; + + // If the contents of both constant arrays are known, fold a call to + // memcmp(A, B, N) to + // N <= Pos ? 0 : (A < B ? -1 : B < A ? +1 : 0) + // where Pos is the first mismatch between A and B, determined below. + + uint64_t Pos = 0; + Value *Zero = ConstantInt::get(CI->getType(), 0); + for (uint64_t MinSize = std::min(LStr.size(), RStr.size()); ; ++Pos) { + if (Pos == MinSize || + (StrNCmp && (LStr[Pos] == '\0' && RStr[Pos] == '\0'))) { + // One array is a leading part of the other of equal or greater + // size, or for strncmp, the arrays are equal strings. + // Fold the result to zero. Size is assumed to be in bounds, since + // otherwise the call would be undefined. + return Zero; + } + + if (LStr[Pos] != RStr[Pos]) + break; + } + + // Normalize the result. + typedef unsigned char UChar; + int IRes = UChar(LStr[Pos]) < UChar(RStr[Pos]) ? -1 : 1; + Value *MaxSize = ConstantInt::get(Size->getType(), Pos); + Value *Cmp = B.CreateICmp(ICmpInst::ICMP_ULE, Size, MaxSize); + Value *Res = ConstantInt::get(CI->getType(), IRes); + return B.CreateSelect(Cmp, Zero, Res); } +// Optimize a memcmp call CI with constant size Len. static Value *optimizeMemCmpConstantSize(CallInst *CI, Value *LHS, Value *RHS, uint64_t Len, IRBuilderBase &B, const DataLayout &DL) { @@ -1028,25 +1254,6 @@ static Value *optimizeMemCmpConstantSize(CallInst *CI, Value *LHS, Value *RHS, } } - // Constant folding: memcmp(x, y, Len) -> constant (all arguments are const). - // TODO: This is limited to i8 arrays. - StringRef LHSStr, RHSStr; - if (getConstantStringInfo(LHS, LHSStr) && - getConstantStringInfo(RHS, RHSStr)) { - // Make sure we're not reading out-of-bounds memory. - if (Len > LHSStr.size() || Len > RHSStr.size()) - return nullptr; - // Fold the memcmp and normalize the result. This way we get consistent - // results across multiple platforms. - uint64_t Ret = 0; - int Cmp = memcmp(LHSStr.data(), RHSStr.data(), Len); - if (Cmp < 0) - Ret = -1; - else if (Cmp > 0) - Ret = 1; - return ConstantInt::get(CI->getType(), Ret); - } - return nullptr; } @@ -1056,33 +1263,29 @@ Value *LibCallSimplifier::optimizeMemCmpBCmpCommon(CallInst *CI, Value *LHS = CI->getArgOperand(0), *RHS = CI->getArgOperand(1); Value *Size = CI->getArgOperand(2); - if (LHS == RHS) // memcmp(s,s,x) -> 0 - return Constant::getNullValue(CI->getType()); - annotateNonNullAndDereferenceable(CI, {0, 1}, Size, DL); - // Handle constant lengths. + + if (Value *Res = optimizeMemCmpVarSize(CI, LHS, RHS, Size, false, B, DL)) + return Res; + + // Handle constant Size. ConstantInt *LenC = dyn_cast<ConstantInt>(Size); if (!LenC) return nullptr; - // memcmp(d,s,0) -> 0 - if (LenC->getZExtValue() == 0) - return Constant::getNullValue(CI->getType()); - - if (Value *Res = - optimizeMemCmpConstantSize(CI, LHS, RHS, LenC->getZExtValue(), B, DL)) - return Res; - return nullptr; + return optimizeMemCmpConstantSize(CI, LHS, RHS, LenC->getZExtValue(), B, DL); } Value *LibCallSimplifier::optimizeMemCmp(CallInst *CI, IRBuilderBase &B) { + Module *M = CI->getModule(); if (Value *V = optimizeMemCmpBCmpCommon(CI, B)) return V; // memcmp(x, y, Len) == 0 -> bcmp(x, y, Len) == 0 // bcmp can be more efficient than memcmp because it only has to know that // there is a difference, not how different one is to the other. - if (TLI->has(LibFunc_bcmp) && isOnlyUsedInZeroEqualityComparison(CI)) { + if (isLibFuncEmittable(M, TLI, LibFunc_bcmp) && + isOnlyUsedInZeroEqualityComparison(CI)) { Value *LHS = CI->getArgOperand(0); Value *RHS = CI->getArgOperand(1); Value *Size = CI->getArgOperand(2); @@ -1125,6 +1328,7 @@ Value *LibCallSimplifier::optimizeMemCCpy(CallInst *CI, IRBuilderBase &B) { return Constant::getNullValue(CI->getType()); if (!getConstantStringInfo(Src, SrcStr, /*Offset=*/0, /*TrimAtNul=*/false) || + // TODO: Handle zeroinitializer. !StopChar) return nullptr; } else { @@ -1246,7 +1450,8 @@ static Value *valueHasFloatPrecision(Value *Val) { /// Shrink double -> float functions. static Value *optimizeDoubleFP(CallInst *CI, IRBuilderBase &B, - bool isBinary, bool isPrecise = false) { + bool isBinary, const TargetLibraryInfo *TLI, + bool isPrecise = false) { Function *CalleeFn = CI->getCalledFunction(); if (!CI->getType()->isDoubleTy() || !CalleeFn) return nullptr; @@ -1296,22 +1501,25 @@ static Value *optimizeDoubleFP(CallInst *CI, IRBuilderBase &B, R = isBinary ? B.CreateCall(Fn, V) : B.CreateCall(Fn, V[0]); } else { AttributeList CalleeAttrs = CalleeFn->getAttributes(); - R = isBinary ? emitBinaryFloatFnCall(V[0], V[1], CalleeName, B, CalleeAttrs) - : emitUnaryFloatFnCall(V[0], CalleeName, B, CalleeAttrs); + R = isBinary ? emitBinaryFloatFnCall(V[0], V[1], TLI, CalleeName, B, + CalleeAttrs) + : emitUnaryFloatFnCall(V[0], TLI, CalleeName, B, CalleeAttrs); } return B.CreateFPExt(R, B.getDoubleTy()); } /// Shrink double -> float for unary functions. static Value *optimizeUnaryDoubleFP(CallInst *CI, IRBuilderBase &B, + const TargetLibraryInfo *TLI, bool isPrecise = false) { - return optimizeDoubleFP(CI, B, false, isPrecise); + return optimizeDoubleFP(CI, B, false, TLI, isPrecise); } /// Shrink double -> float for binary functions. static Value *optimizeBinaryDoubleFP(CallInst *CI, IRBuilderBase &B, + const TargetLibraryInfo *TLI, bool isPrecise = false) { - return optimizeDoubleFP(CI, B, true, isPrecise); + return optimizeDoubleFP(CI, B, true, TLI, isPrecise); } // cabs(z) -> sqrt((creal(z)*creal(z)) + (cimag(z)*cimag(z))) @@ -1427,6 +1635,7 @@ static Value *getIntToFPVal(Value *I2F, IRBuilderBase &B, unsigned DstWidth) { /// ldexp(1.0, x) for pow(2.0, itofp(x)); exp2(n * x) for pow(2.0 ** n, x); /// exp10(x) for pow(10.0, x); exp2(log2(n) * x) for pow(n, x). Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilderBase &B) { + Module *M = Pow->getModule(); Value *Base = Pow->getArgOperand(0), *Expo = Pow->getArgOperand(1); AttributeList Attrs; // Attributes are only meaningful on the original call Module *Mod = Pow->getModule(); @@ -1454,7 +1663,8 @@ Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilderBase &B) { Function *CalleeFn = BaseFn->getCalledFunction(); if (CalleeFn && - TLI->getLibFunc(CalleeFn->getName(), LibFn) && TLI->has(LibFn)) { + TLI->getLibFunc(CalleeFn->getName(), LibFn) && + isLibFuncEmittable(M, TLI, LibFn)) { StringRef ExpName; Intrinsic::ID ID; Value *ExpFn; @@ -1506,7 +1716,7 @@ Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilderBase &B) { // pow(2.0, itofp(x)) -> ldexp(1.0, x) if (match(Base, m_SpecificFP(2.0)) && (isa<SIToFPInst>(Expo) || isa<UIToFPInst>(Expo)) && - hasFloatFn(TLI, Ty, LibFunc_ldexp, LibFunc_ldexpf, LibFunc_ldexpl)) { + hasFloatFn(M, TLI, Ty, LibFunc_ldexp, LibFunc_ldexpf, LibFunc_ldexpl)) { if (Value *ExpoI = getIntToFPVal(Expo, B, TLI->getIntSize())) return copyFlags(*Pow, emitBinaryFloatFnCall(ConstantFP::get(Ty, 1.0), ExpoI, @@ -1515,7 +1725,7 @@ Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilderBase &B) { } // pow(2.0 ** n, x) -> exp2(n * x) - if (hasFloatFn(TLI, Ty, LibFunc_exp2, LibFunc_exp2f, LibFunc_exp2l)) { + if (hasFloatFn(M, TLI, Ty, LibFunc_exp2, LibFunc_exp2f, LibFunc_exp2l)) { APFloat BaseR = APFloat(1.0); BaseR.convert(BaseF->getSemantics(), APFloat::rmTowardZero, &Ignored); BaseR = BaseR / *BaseF; @@ -1542,7 +1752,7 @@ Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilderBase &B) { // pow(10.0, x) -> exp10(x) // TODO: There is no exp10() intrinsic yet, but some day there shall be one. if (match(Base, m_SpecificFP(10.0)) && - hasFloatFn(TLI, Ty, LibFunc_exp10, LibFunc_exp10f, LibFunc_exp10l)) + hasFloatFn(M, TLI, Ty, LibFunc_exp10, LibFunc_exp10f, LibFunc_exp10l)) return copyFlags(*Pow, emitUnaryFloatFnCall(Expo, TLI, LibFunc_exp10, LibFunc_exp10f, LibFunc_exp10l, B, Attrs)); @@ -1567,7 +1777,8 @@ Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilderBase &B) { return copyFlags(*Pow, B.CreateCall(Intrinsic::getDeclaration( Mod, Intrinsic::exp2, Ty), FMul, "exp2")); - else if (hasFloatFn(TLI, Ty, LibFunc_exp2, LibFunc_exp2f, LibFunc_exp2l)) + else if (hasFloatFn(M, TLI, Ty, LibFunc_exp2, LibFunc_exp2f, + LibFunc_exp2l)) return copyFlags(*Pow, emitUnaryFloatFnCall(FMul, TLI, LibFunc_exp2, LibFunc_exp2f, LibFunc_exp2l, B, Attrs)); @@ -1588,7 +1799,8 @@ static Value *getSqrtCall(Value *V, AttributeList Attrs, bool NoErrno, } // Otherwise, use the libcall for sqrt(). - if (hasFloatFn(TLI, V->getType(), LibFunc_sqrt, LibFunc_sqrtf, LibFunc_sqrtl)) + if (hasFloatFn(M, TLI, V->getType(), LibFunc_sqrt, LibFunc_sqrtf, + LibFunc_sqrtl)) // TODO: We also should check that the target can in fact lower the sqrt() // libcall. We currently have no way to ask this question, so we ask if // the target has a sqrt() libcall, which is not exactly the same. @@ -1778,8 +1990,8 @@ Value *LibCallSimplifier::optimizePow(CallInst *Pow, IRBuilderBase &B) { // Shrink pow() to powf() if the arguments are single precision, // unless the result is expected to be double precision. if (UnsafeFPShrink && Name == TLI->getName(LibFunc_pow) && - hasFloatVersion(Name)) { - if (Value *Shrunk = optimizeBinaryDoubleFP(Pow, B, true)) + hasFloatVersion(M, Name)) { + if (Value *Shrunk = optimizeBinaryDoubleFP(Pow, B, TLI, true)) return Shrunk; } @@ -1787,13 +1999,14 @@ Value *LibCallSimplifier::optimizePow(CallInst *Pow, IRBuilderBase &B) { } Value *LibCallSimplifier::optimizeExp2(CallInst *CI, IRBuilderBase &B) { + Module *M = CI->getModule(); Function *Callee = CI->getCalledFunction(); AttributeList Attrs; // Attributes are only meaningful on the original call StringRef Name = Callee->getName(); Value *Ret = nullptr; if (UnsafeFPShrink && Name == TLI->getName(LibFunc_exp2) && - hasFloatVersion(Name)) - Ret = optimizeUnaryDoubleFP(CI, B, true); + hasFloatVersion(M, Name)) + Ret = optimizeUnaryDoubleFP(CI, B, TLI, true); Type *Ty = CI->getType(); Value *Op = CI->getArgOperand(0); @@ -1801,7 +2014,7 @@ Value *LibCallSimplifier::optimizeExp2(CallInst *CI, IRBuilderBase &B) { // Turn exp2(sitofp(x)) -> ldexp(1.0, sext(x)) if sizeof(x) <= IntSize // Turn exp2(uitofp(x)) -> ldexp(1.0, zext(x)) if sizeof(x) < IntSize if ((isa<SIToFPInst>(Op) || isa<UIToFPInst>(Op)) && - hasFloatFn(TLI, Ty, LibFunc_ldexp, LibFunc_ldexpf, LibFunc_ldexpl)) { + hasFloatFn(M, TLI, Ty, LibFunc_ldexp, LibFunc_ldexpf, LibFunc_ldexpl)) { if (Value *Exp = getIntToFPVal(Op, B, TLI->getIntSize())) return emitBinaryFloatFnCall(ConstantFP::get(Ty, 1.0), Exp, TLI, LibFunc_ldexp, LibFunc_ldexpf, LibFunc_ldexpl, @@ -1812,12 +2025,14 @@ Value *LibCallSimplifier::optimizeExp2(CallInst *CI, IRBuilderBase &B) { } Value *LibCallSimplifier::optimizeFMinFMax(CallInst *CI, IRBuilderBase &B) { + Module *M = CI->getModule(); + // If we can shrink the call to a float function rather than a double // function, do that first. Function *Callee = CI->getCalledFunction(); StringRef Name = Callee->getName(); - if ((Name == "fmin" || Name == "fmax") && hasFloatVersion(Name)) - if (Value *Ret = optimizeBinaryDoubleFP(CI, B)) + if ((Name == "fmin" || Name == "fmax") && hasFloatVersion(M, Name)) + if (Value *Ret = optimizeBinaryDoubleFP(CI, B, TLI)) return Ret; // The LLVM intrinsics minnum/maxnum correspond to fmin/fmax. Canonicalize to @@ -1848,8 +2063,8 @@ Value *LibCallSimplifier::optimizeLog(CallInst *Log, IRBuilderBase &B) { Type *Ty = Log->getType(); Value *Ret = nullptr; - if (UnsafeFPShrink && hasFloatVersion(LogNm)) - Ret = optimizeUnaryDoubleFP(Log, B, true); + if (UnsafeFPShrink && hasFloatVersion(Mod, LogNm)) + Ret = optimizeUnaryDoubleFP(Log, B, TLI, true); // The earlier call must also be 'fast' in order to do these transforms. CallInst *Arg = dyn_cast<CallInst>(Log->getArgOperand(0)); @@ -1957,7 +2172,7 @@ Value *LibCallSimplifier::optimizeLog(CallInst *Log, IRBuilderBase &B) { Log->doesNotAccessMemory() ? B.CreateCall(Intrinsic::getDeclaration(Mod, LogID, Ty), Arg->getOperand(0), "log") - : emitUnaryFloatFnCall(Arg->getOperand(0), LogNm, B, Attrs); + : emitUnaryFloatFnCall(Arg->getOperand(0), TLI, LogNm, B, Attrs); Value *MulY = B.CreateFMul(Arg->getArgOperand(1), LogX, "mul"); // Since pow() may have side effects, e.g. errno, // dead code elimination may not be trusted to remove it. @@ -1980,7 +2195,7 @@ Value *LibCallSimplifier::optimizeLog(CallInst *Log, IRBuilderBase &B) { Value *LogE = Log->doesNotAccessMemory() ? B.CreateCall(Intrinsic::getDeclaration(Mod, LogID, Ty), Eul, "log") - : emitUnaryFloatFnCall(Eul, LogNm, B, Attrs); + : emitUnaryFloatFnCall(Eul, TLI, LogNm, B, Attrs); Value *MulY = B.CreateFMul(Arg->getArgOperand(0), LogE, "mul"); // Since exp() may have side effects, e.g. errno, // dead code elimination may not be trusted to remove it. @@ -1992,14 +2207,16 @@ Value *LibCallSimplifier::optimizeLog(CallInst *Log, IRBuilderBase &B) { } Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) { + Module *M = CI->getModule(); Function *Callee = CI->getCalledFunction(); Value *Ret = nullptr; // TODO: Once we have a way (other than checking for the existince of the // libcall) to tell whether our target can lower @llvm.sqrt, relax the // condition below. - if (TLI->has(LibFunc_sqrtf) && (Callee->getName() == "sqrt" || - Callee->getIntrinsicID() == Intrinsic::sqrt)) - Ret = optimizeUnaryDoubleFP(CI, B, true); + if (isLibFuncEmittable(M, TLI, LibFunc_sqrtf) && + (Callee->getName() == "sqrt" || + Callee->getIntrinsicID() == Intrinsic::sqrt)) + Ret = optimizeUnaryDoubleFP(CI, B, TLI, true); if (!CI->isFast()) return Ret; @@ -2044,7 +2261,6 @@ Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) { // If we found a repeated factor, hoist it out of the square root and // replace it with the fabs of that factor. - Module *M = Callee->getParent(); Type *ArgType = I->getType(); Function *Fabs = Intrinsic::getDeclaration(M, Intrinsic::fabs, ArgType); Value *FabsCall = B.CreateCall(Fabs, RepeatOp, "fabs"); @@ -2061,11 +2277,12 @@ Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) { // TODO: Generalize to handle any trig function and its inverse. Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilderBase &B) { + Module *M = CI->getModule(); Function *Callee = CI->getCalledFunction(); Value *Ret = nullptr; StringRef Name = Callee->getName(); - if (UnsafeFPShrink && Name == "tan" && hasFloatVersion(Name)) - Ret = optimizeUnaryDoubleFP(CI, B, true); + if (UnsafeFPShrink && Name == "tan" && hasFloatVersion(M, Name)) + Ret = optimizeUnaryDoubleFP(CI, B, TLI, true); Value *Op1 = CI->getArgOperand(0); auto *OpC = dyn_cast<CallInst>(Op1); @@ -2081,7 +2298,8 @@ Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilderBase &B) { // tanl(atanl(x)) -> x LibFunc Func; Function *F = OpC->getCalledFunction(); - if (F && TLI->getLibFunc(F->getName(), Func) && TLI->has(Func) && + if (F && TLI->getLibFunc(F->getName(), Func) && + isLibFuncEmittable(M, TLI, Func) && ((Func == LibFunc_atan && Callee->getName() == "tan") || (Func == LibFunc_atanf && Callee->getName() == "tanf") || (Func == LibFunc_atanl && Callee->getName() == "tanl"))) @@ -2097,9 +2315,10 @@ static bool isTrigLibCall(CallInst *CI) { CI->hasFnAttr(Attribute::ReadNone); } -static void insertSinCosCall(IRBuilderBase &B, Function *OrigCallee, Value *Arg, +static bool insertSinCosCall(IRBuilderBase &B, Function *OrigCallee, Value *Arg, bool UseFloat, Value *&Sin, Value *&Cos, - Value *&SinCos) { + Value *&SinCos, const TargetLibraryInfo *TLI) { + Module *M = OrigCallee->getParent(); Type *ArgTy = Arg->getType(); Type *ResTy; StringRef Name; @@ -2119,9 +2338,12 @@ static void insertSinCosCall(IRBuilderBase &B, Function *OrigCallee, Value *Arg, ResTy = StructType::get(ArgTy, ArgTy); } - Module *M = OrigCallee->getParent(); - FunctionCallee Callee = - M->getOrInsertFunction(Name, OrigCallee->getAttributes(), ResTy, ArgTy); + if (!isLibFuncEmittable(M, TLI, Name)) + return false; + LibFunc TheLibFunc; + TLI->getLibFunc(Name, TheLibFunc); + FunctionCallee Callee = getOrInsertLibFunc( + M, *TLI, TheLibFunc, OrigCallee->getAttributes(), ResTy, ArgTy); if (Instruction *ArgInst = dyn_cast<Instruction>(Arg)) { // If the argument is an instruction, it must dominate all uses so put our @@ -2145,6 +2367,8 @@ static void insertSinCosCall(IRBuilderBase &B, Function *OrigCallee, Value *Arg, Cos = B.CreateExtractElement(SinCos, ConstantInt::get(B.getInt32Ty(), 1), "cospi"); } + + return true; } Value *LibCallSimplifier::optimizeSinCosPi(CallInst *CI, IRBuilderBase &B) { @@ -2172,7 +2396,9 @@ Value *LibCallSimplifier::optimizeSinCosPi(CallInst *CI, IRBuilderBase &B) { return nullptr; Value *Sin, *Cos, *SinCos; - insertSinCosCall(B, CI->getCalledFunction(), Arg, IsFloat, Sin, Cos, SinCos); + if (!insertSinCosCall(B, CI->getCalledFunction(), Arg, IsFloat, Sin, Cos, + SinCos, TLI)) + return nullptr; auto replaceTrigInsts = [this](SmallVectorImpl<CallInst *> &Calls, Value *Res) { @@ -2193,6 +2419,7 @@ void LibCallSimplifier::classifyArgUse( SmallVectorImpl<CallInst *> &CosCalls, SmallVectorImpl<CallInst *> &SinCosCalls) { CallInst *CI = dyn_cast<CallInst>(Val); + Module *M = CI->getModule(); if (!CI || CI->use_empty()) return; @@ -2203,7 +2430,8 @@ void LibCallSimplifier::classifyArgUse( Function *Callee = CI->getCalledFunction(); LibFunc Func; - if (!Callee || !TLI->getLibFunc(*Callee, Func) || !TLI->has(Func) || + if (!Callee || !TLI->getLibFunc(*Callee, Func) || + !isLibFuncEmittable(M, TLI, Func) || !isTrigLibCall(CI)) return; @@ -2258,7 +2486,7 @@ Value *LibCallSimplifier::optimizeAbs(CallInst *CI, IRBuilderBase &B) { // abs(x) -> x <s 0 ? -x : x // The negation has 'nsw' because abs of INT_MIN is undefined. Value *X = CI->getArgOperand(0); - Value *IsNeg = B.CreateICmpSLT(X, Constant::getNullValue(X->getType())); + Value *IsNeg = B.CreateIsNeg(X); Value *NegX = B.CreateNSWNeg(X, "neg"); return B.CreateSelect(IsNeg, NegX, X); } @@ -2418,6 +2646,7 @@ Value *LibCallSimplifier::optimizePrintFString(CallInst *CI, IRBuilderBase &B) { Value *LibCallSimplifier::optimizePrintF(CallInst *CI, IRBuilderBase &B) { + Module *M = CI->getModule(); Function *Callee = CI->getCalledFunction(); FunctionType *FT = Callee->getFunctionType(); if (Value *V = optimizePrintFString(CI, B)) { @@ -2426,10 +2655,10 @@ Value *LibCallSimplifier::optimizePrintF(CallInst *CI, IRBuilderBase &B) { // printf(format, ...) -> iprintf(format, ...) if no floating point // arguments. - if (TLI->has(LibFunc_iprintf) && !callHasFloatingPointArgument(CI)) { - Module *M = B.GetInsertBlock()->getParent()->getParent(); - FunctionCallee IPrintFFn = - M->getOrInsertFunction("iprintf", FT, Callee->getAttributes()); + if (isLibFuncEmittable(M, TLI, LibFunc_iprintf) && + !callHasFloatingPointArgument(CI)) { + FunctionCallee IPrintFFn = getOrInsertLibFunc(M, *TLI, LibFunc_iprintf, FT, + Callee->getAttributes()); CallInst *New = cast<CallInst>(CI->clone()); New->setCalledFunction(IPrintFFn); B.Insert(New); @@ -2438,11 +2667,10 @@ Value *LibCallSimplifier::optimizePrintF(CallInst *CI, IRBuilderBase &B) { // printf(format, ...) -> __small_printf(format, ...) if no 128-bit floating point // arguments. - if (TLI->has(LibFunc_small_printf) && !callHasFP128Argument(CI)) { - Module *M = B.GetInsertBlock()->getParent()->getParent(); - auto SmallPrintFFn = - M->getOrInsertFunction(TLI->getName(LibFunc_small_printf), - FT, Callee->getAttributes()); + if (isLibFuncEmittable(M, TLI, LibFunc_small_printf) && + !callHasFP128Argument(CI)) { + auto SmallPrintFFn = getOrInsertLibFunc(M, *TLI, LibFunc_small_printf, FT, + Callee->getAttributes()); CallInst *New = cast<CallInst>(CI->clone()); New->setCalledFunction(SmallPrintFFn); B.Insert(New); @@ -2489,7 +2717,7 @@ Value *LibCallSimplifier::optimizeSPrintFString(CallInst *CI, Value *V = B.CreateTrunc(CI->getArgOperand(2), B.getInt8Ty(), "char"); Value *Ptr = castToCStr(Dest, B); B.CreateStore(V, Ptr); - Ptr = B.CreateGEP(B.getInt8Ty(), Ptr, B.getInt32(1), "nul"); + Ptr = B.CreateInBoundsGEP(B.getInt8Ty(), Ptr, B.getInt32(1), "nul"); B.CreateStore(B.getInt8(0), Ptr); return ConstantInt::get(CI->getType(), 1); @@ -2541,6 +2769,7 @@ Value *LibCallSimplifier::optimizeSPrintFString(CallInst *CI, } Value *LibCallSimplifier::optimizeSPrintF(CallInst *CI, IRBuilderBase &B) { + Module *M = CI->getModule(); Function *Callee = CI->getCalledFunction(); FunctionType *FT = Callee->getFunctionType(); if (Value *V = optimizeSPrintFString(CI, B)) { @@ -2549,10 +2778,10 @@ Value *LibCallSimplifier::optimizeSPrintF(CallInst *CI, IRBuilderBase &B) { // sprintf(str, format, ...) -> siprintf(str, format, ...) if no floating // point arguments. - if (TLI->has(LibFunc_siprintf) && !callHasFloatingPointArgument(CI)) { - Module *M = B.GetInsertBlock()->getParent()->getParent(); - FunctionCallee SIPrintFFn = - M->getOrInsertFunction("siprintf", FT, Callee->getAttributes()); + if (isLibFuncEmittable(M, TLI, LibFunc_siprintf) && + !callHasFloatingPointArgument(CI)) { + FunctionCallee SIPrintFFn = getOrInsertLibFunc(M, *TLI, LibFunc_siprintf, + FT, Callee->getAttributes()); CallInst *New = cast<CallInst>(CI->clone()); New->setCalledFunction(SIPrintFFn); B.Insert(New); @@ -2561,11 +2790,10 @@ Value *LibCallSimplifier::optimizeSPrintF(CallInst *CI, IRBuilderBase &B) { // sprintf(str, format, ...) -> __small_sprintf(str, format, ...) if no 128-bit // floating point arguments. - if (TLI->has(LibFunc_small_sprintf) && !callHasFP128Argument(CI)) { - Module *M = B.GetInsertBlock()->getParent()->getParent(); - auto SmallSPrintFFn = - M->getOrInsertFunction(TLI->getName(LibFunc_small_sprintf), - FT, Callee->getAttributes()); + if (isLibFuncEmittable(M, TLI, LibFunc_small_sprintf) && + !callHasFP128Argument(CI)) { + auto SmallSPrintFFn = getOrInsertLibFunc(M, *TLI, LibFunc_small_sprintf, FT, + Callee->getAttributes()); CallInst *New = cast<CallInst>(CI->clone()); New->setCalledFunction(SmallSPrintFFn); B.Insert(New); @@ -2629,7 +2857,7 @@ Value *LibCallSimplifier::optimizeSnPrintFString(CallInst *CI, Value *V = B.CreateTrunc(CI->getArgOperand(3), B.getInt8Ty(), "char"); Value *Ptr = castToCStr(CI->getArgOperand(0), B); B.CreateStore(V, Ptr); - Ptr = B.CreateGEP(B.getInt8Ty(), Ptr, B.getInt32(1), "nul"); + Ptr = B.CreateInBoundsGEP(B.getInt8Ty(), Ptr, B.getInt32(1), "nul"); B.CreateStore(B.getInt8(0), Ptr); return ConstantInt::get(CI->getType(), 1); @@ -2721,6 +2949,7 @@ Value *LibCallSimplifier::optimizeFPrintFString(CallInst *CI, } Value *LibCallSimplifier::optimizeFPrintF(CallInst *CI, IRBuilderBase &B) { + Module *M = CI->getModule(); Function *Callee = CI->getCalledFunction(); FunctionType *FT = Callee->getFunctionType(); if (Value *V = optimizeFPrintFString(CI, B)) { @@ -2729,10 +2958,10 @@ Value *LibCallSimplifier::optimizeFPrintF(CallInst *CI, IRBuilderBase &B) { // fprintf(stream, format, ...) -> fiprintf(stream, format, ...) if no // floating point arguments. - if (TLI->has(LibFunc_fiprintf) && !callHasFloatingPointArgument(CI)) { - Module *M = B.GetInsertBlock()->getParent()->getParent(); - FunctionCallee FIPrintFFn = - M->getOrInsertFunction("fiprintf", FT, Callee->getAttributes()); + if (isLibFuncEmittable(M, TLI, LibFunc_fiprintf) && + !callHasFloatingPointArgument(CI)) { + FunctionCallee FIPrintFFn = getOrInsertLibFunc(M, *TLI, LibFunc_fiprintf, + FT, Callee->getAttributes()); CallInst *New = cast<CallInst>(CI->clone()); New->setCalledFunction(FIPrintFFn); B.Insert(New); @@ -2741,11 +2970,11 @@ Value *LibCallSimplifier::optimizeFPrintF(CallInst *CI, IRBuilderBase &B) { // fprintf(stream, format, ...) -> __small_fprintf(stream, format, ...) if no // 128-bit floating point arguments. - if (TLI->has(LibFunc_small_fprintf) && !callHasFP128Argument(CI)) { - Module *M = B.GetInsertBlock()->getParent()->getParent(); + if (isLibFuncEmittable(M, TLI, LibFunc_small_fprintf) && + !callHasFP128Argument(CI)) { auto SmallFPrintFFn = - M->getOrInsertFunction(TLI->getName(LibFunc_small_fprintf), - FT, Callee->getAttributes()); + getOrInsertLibFunc(M, *TLI, LibFunc_small_fprintf, FT, + Callee->getAttributes()); CallInst *New = cast<CallInst>(CI->clone()); New->setCalledFunction(SmallFPrintFFn); B.Insert(New); @@ -2830,21 +3059,19 @@ Value *LibCallSimplifier::optimizeBCopy(CallInst *CI, IRBuilderBase &B) { CI->getArgOperand(2))); } -bool LibCallSimplifier::hasFloatVersion(StringRef FuncName) { - LibFunc Func; +bool LibCallSimplifier::hasFloatVersion(const Module *M, StringRef FuncName) { SmallString<20> FloatFuncName = FuncName; FloatFuncName += 'f'; - if (TLI->getLibFunc(FloatFuncName, Func)) - return TLI->has(Func); - return false; + return isLibFuncEmittable(M, TLI, FloatFuncName); } Value *LibCallSimplifier::optimizeStringMemoryLibCall(CallInst *CI, IRBuilderBase &Builder) { + Module *M = CI->getModule(); LibFunc Func; Function *Callee = CI->getCalledFunction(); // Check for string/memory library functions. - if (TLI->getLibFunc(*Callee, Func) && TLI->has(Func)) { + if (TLI->getLibFunc(*Callee, Func) && isLibFuncEmittable(M, TLI, Func)) { // Make sure we never change the calling convention. assert( (ignoreCallingConv(Func) || @@ -2871,6 +3098,8 @@ Value *LibCallSimplifier::optimizeStringMemoryLibCall(CallInst *CI, return optimizeStrNCpy(CI, Builder); case LibFunc_strlen: return optimizeStrLen(CI, Builder); + case LibFunc_strnlen: + return optimizeStrNLen(CI, Builder); case LibFunc_strpbrk: return optimizeStrPBrk(CI, Builder); case LibFunc_strndup: @@ -2923,6 +3152,8 @@ Value *LibCallSimplifier::optimizeStringMemoryLibCall(CallInst *CI, Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI, LibFunc Func, IRBuilderBase &Builder) { + const Module *M = CI->getModule(); + // Don't optimize calls that require strict floating point semantics. if (CI->isStrictFP()) return nullptr; @@ -3001,12 +3232,12 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI, case LibFunc_sin: case LibFunc_sinh: case LibFunc_tanh: - if (UnsafeFPShrink && hasFloatVersion(CI->getCalledFunction()->getName())) - return optimizeUnaryDoubleFP(CI, Builder, true); + if (UnsafeFPShrink && hasFloatVersion(M, CI->getCalledFunction()->getName())) + return optimizeUnaryDoubleFP(CI, Builder, TLI, true); return nullptr; case LibFunc_copysign: - if (hasFloatVersion(CI->getCalledFunction()->getName())) - return optimizeBinaryDoubleFP(CI, Builder); + if (hasFloatVersion(M, CI->getCalledFunction()->getName())) + return optimizeBinaryDoubleFP(CI, Builder, TLI); return nullptr; case LibFunc_fminf: case LibFunc_fmin: @@ -3025,6 +3256,7 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI, } Value *LibCallSimplifier::optimizeCall(CallInst *CI, IRBuilderBase &Builder) { + Module *M = CI->getModule(); assert(!CI->isMustTailCall() && "These transforms aren't musttail safe."); // TODO: Split out the code below that operates on FP calls so that @@ -3103,7 +3335,7 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI, IRBuilderBase &Builder) { } // Then check for known library functions. - if (TLI->getLibFunc(*Callee, Func) && TLI->has(Func)) { + if (TLI->getLibFunc(*Callee, Func) && isLibFuncEmittable(M, TLI, Func)) { // We never change the calling convention. if (!ignoreCallingConv(Func) && !IsCallingConvC) return nullptr; @@ -3170,7 +3402,7 @@ LibCallSimplifier::LibCallSimplifier( function_ref<void(Instruction *, Value *)> Replacer, function_ref<void(Instruction *)> Eraser) : FortifiedSimplifier(TLI), DL(DL), TLI(TLI), ORE(ORE), BFI(BFI), PSI(PSI), - UnsafeFPShrink(false), Replacer(Replacer), Eraser(Eraser) {} + Replacer(Replacer), Eraser(Eraser) {} void LibCallSimplifier::replaceAllUsesWith(Instruction *I, Value *With) { // Indirect through the replacer used in this instance. @@ -3361,7 +3593,8 @@ Value *FortifiedLibCallSimplifier::optimizeStrpCpyChk(CallInst *CI, // If the function was an __stpcpy_chk, and we were able to fold it into // a __memcpy_chk, we still need to return the correct end pointer. if (Ret && Func == LibFunc_stpcpy_chk) - return B.CreateGEP(B.getInt8Ty(), Dst, ConstantInt::get(SizeTTy, Len - 1)); + return B.CreateInBoundsGEP(B.getInt8Ty(), Dst, + ConstantInt::get(SizeTTy, Len - 1)); return copyFlags(*CI, cast<CallInst>(Ret)); } |
