diff options
author | Dimitry Andric <dim@FreeBSD.org> | 2019-01-19 10:01:25 +0000 |
---|---|---|
committer | Dimitry Andric <dim@FreeBSD.org> | 2019-01-19 10:01:25 +0000 |
commit | d8e91e46262bc44006913e6796843909f1ac7bcd (patch) | |
tree | 7d0c143d9b38190e0fa0180805389da22cd834c5 /lib/Transforms/Utils/SimplifyLibCalls.cpp | |
parent | b7eb8e35e481a74962664b63dfb09483b200209a (diff) |
Notes
Diffstat (limited to 'lib/Transforms/Utils/SimplifyLibCalls.cpp')
-rw-r--r-- | lib/Transforms/Utils/SimplifyLibCalls.cpp | 607 |
1 files changed, 397 insertions, 210 deletions
diff --git a/lib/Transforms/Utils/SimplifyLibCalls.cpp b/lib/Transforms/Utils/SimplifyLibCalls.cpp index 15e035874002..1bb26caa2af2 100644 --- a/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ b/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -13,6 +13,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/SimplifyLibCalls.h" +#include "llvm/ADT/APSInt.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/Triple.h" @@ -22,6 +23,7 @@ #include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/Analysis/CaptureTracking.h" +#include "llvm/Analysis/Loads.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" @@ -150,6 +152,32 @@ static bool isLocallyOpenedFile(Value *File, CallInst *CI, IRBuilder<> &B, return true; } +static bool isOnlyUsedInComparisonWithZero(Value *V) { + for (User *U : V->users()) { + if (ICmpInst *IC = dyn_cast<ICmpInst>(U)) + if (Constant *C = dyn_cast<Constant>(IC->getOperand(1))) + if (C->isNullValue()) + continue; + // Unknown instruction. + return false; + } + return true; +} + +static bool canTransformToMemCmp(CallInst *CI, Value *Str, uint64_t Len, + const DataLayout &DL) { + if (!isOnlyUsedInComparisonWithZero(CI)) + return false; + + if (!isDereferenceableAndAlignedPointer(Str, 1, APInt(64, Len), DL)) + return false; + + if (CI->getFunction()->hasFnAttribute(Attribute::SanitizeMemory)) + return false; + + return true; +} + //===----------------------------------------------------------------------===// // String and Memory Library Call Optimizations //===----------------------------------------------------------------------===// @@ -322,6 +350,21 @@ Value *LibCallSimplifier::optimizeStrCmp(CallInst *CI, IRBuilder<> &B) { B, DL, TLI); } + // strcmp to memcmp + if (!HasStr1 && HasStr2) { + if (canTransformToMemCmp(CI, Str1P, Len2, DL)) + return emitMemCmp( + Str1P, Str2P, + ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len2), B, DL, + TLI); + } else if (HasStr1 && !HasStr2) { + if (canTransformToMemCmp(CI, Str2P, Len1, DL)) + return emitMemCmp( + Str1P, Str2P, + ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len1), B, DL, + TLI); + } + return nullptr; } @@ -361,6 +404,26 @@ Value *LibCallSimplifier::optimizeStrNCmp(CallInst *CI, IRBuilder<> &B) { if (HasStr2 && Str2.empty()) // strncmp(x, "", n) -> *x return B.CreateZExt(B.CreateLoad(Str1P, "strcmpload"), CI->getType()); + uint64_t Len1 = GetStringLength(Str1P); + uint64_t Len2 = GetStringLength(Str2P); + + // strncmp to memcmp + if (!HasStr1 && HasStr2) { + Len2 = std::min(Len2, Length); + if (canTransformToMemCmp(CI, Str1P, Len2, DL)) + return emitMemCmp( + Str1P, Str2P, + ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len2), B, DL, + TLI); + } else if (HasStr1 && !HasStr2) { + Len1 = std::min(Len1, Length); + if (canTransformToMemCmp(CI, Str2P, Len1, DL)) + return emitMemCmp( + Str1P, Str2P, + ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len1), B, DL, + TLI); + } + return nullptr; } @@ -735,8 +798,11 @@ Value *LibCallSimplifier::optimizeMemChr(CallInst *CI, IRBuilder<> &B) { Bitfield.setBit((unsigned char)C); Value *BitfieldC = B.getInt(Bitfield); - // First check that the bit field access is within bounds. + // Adjust width of "C" to the bitfield width, then mask off the high bits. Value *C = B.CreateZExtOrTrunc(CI->getArgOperand(1), BitfieldC->getType()); + C = B.CreateAnd(C, B.getIntN(Width, 0xFF)); + + // First check that the bit field access is within bounds. Value *Bounds = B.CreateICmp(ICmpInst::ICMP_ULT, C, B.getIntN(Width, Width), "memchr.bounds"); @@ -860,8 +926,7 @@ Value *LibCallSimplifier::optimizeMemMove(CallInst *CI, IRBuilder<> &B) { } /// Fold memset[_chk](malloc(n), 0, n) --> calloc(1, n). -static Value *foldMallocMemset(CallInst *Memset, IRBuilder<> &B, - const TargetLibraryInfo &TLI) { +Value *LibCallSimplifier::foldMallocMemset(CallInst *Memset, IRBuilder<> &B) { // This has to be a memset of zeros (bzero). auto *FillValue = dyn_cast<ConstantInt>(Memset->getArgOperand(1)); if (!FillValue || FillValue->getZExtValue() != 0) @@ -881,7 +946,7 @@ static Value *foldMallocMemset(CallInst *Memset, IRBuilder<> &B, return nullptr; LibFunc Func; - if (!TLI.getLibFunc(*InnerCallee, Func) || !TLI.has(Func) || + if (!TLI->getLibFunc(*InnerCallee, Func) || !TLI->has(Func) || Func != LibFunc_malloc) return nullptr; @@ -896,18 +961,18 @@ static Value *foldMallocMemset(CallInst *Memset, IRBuilder<> &B, IntegerType *SizeType = DL.getIntPtrType(B.GetInsertBlock()->getContext()); Value *Calloc = emitCalloc(ConstantInt::get(SizeType, 1), Malloc->getArgOperand(0), Malloc->getAttributes(), - B, TLI); + B, *TLI); if (!Calloc) return nullptr; Malloc->replaceAllUsesWith(Calloc); - Malloc->eraseFromParent(); + eraseFromParent(Malloc); return Calloc; } Value *LibCallSimplifier::optimizeMemSet(CallInst *CI, IRBuilder<> &B) { - if (auto *Calloc = foldMallocMemset(CI, B, *TLI)) + if (auto *Calloc = foldMallocMemset(CI, B)) return Calloc; // memset(p, v, n) -> llvm.memset(align 1 p, v, n) @@ -927,6 +992,20 @@ Value *LibCallSimplifier::optimizeRealloc(CallInst *CI, IRBuilder<> &B) { // Math Library Optimizations //===----------------------------------------------------------------------===// +// Replace a libcall \p CI with a call to intrinsic \p IID +static Value *replaceUnaryCall(CallInst *CI, IRBuilder<> &B, Intrinsic::ID IID) { + // Propagate fast-math flags from the existing call to the new call. + IRBuilder<>::FastMathFlagGuard Guard(B); + B.setFastMathFlags(CI->getFastMathFlags()); + + Module *M = CI->getModule(); + Value *V = CI->getArgOperand(0); + Function *F = Intrinsic::getDeclaration(M, IID, CI->getType()); + CallInst *NewCall = B.CreateCall(F, V); + NewCall->takeName(CI); + return NewCall; +} + /// Return a variant of Val with float type. /// Currently this works in two cases: If Val is an FPExtension of a float /// value to something bigger, simply return the operand. @@ -949,104 +1028,75 @@ static Value *valueHasFloatPrecision(Value *Val) { return nullptr; } -/// Shrink double -> float for unary functions like 'floor'. -static Value *optimizeUnaryDoubleFP(CallInst *CI, IRBuilder<> &B, - bool CheckRetType) { - Function *Callee = CI->getCalledFunction(); - // We know this libcall has a valid prototype, but we don't know which. +/// Shrink double -> float functions. +static Value *optimizeDoubleFP(CallInst *CI, IRBuilder<> &B, + bool isBinary, bool isPrecise = false) { if (!CI->getType()->isDoubleTy()) return nullptr; - if (CheckRetType) { - // Check if all the uses for function like 'sin' are converted to float. + // If not all the uses of the function are converted to float, then bail out. + // This matters if the precision of the result is more important than the + // precision of the arguments. + if (isPrecise) for (User *U : CI->users()) { FPTruncInst *Cast = dyn_cast<FPTruncInst>(U); if (!Cast || !Cast->getType()->isFloatTy()) return nullptr; } - } - // If this is something like 'floor((double)floatval)', convert to floorf. - Value *V = valueHasFloatPrecision(CI->getArgOperand(0)); - if (V == nullptr) + // If this is something like 'g((double) float)', convert to 'gf(float)'. + Value *V[2]; + V[0] = valueHasFloatPrecision(CI->getArgOperand(0)); + V[1] = isBinary ? valueHasFloatPrecision(CI->getArgOperand(1)) : nullptr; + if (!V[0] || (isBinary && !V[1])) return nullptr; // If call isn't an intrinsic, check that it isn't within a function with the - // same name as the float version of this call. + // same name as the float version of this call, otherwise the result is an + // infinite loop. For example, from MinGW-w64: // - // e.g. inline float expf(float val) { return (float) exp((double) val); } - // - // A similar such definition exists in the MinGW-w64 math.h header file which - // when compiled with -O2 -ffast-math causes the generation of infinite loops - // where expf is called. - if (!Callee->isIntrinsic()) { - const Function *F = CI->getFunction(); - StringRef FName = F->getName(); - StringRef CalleeName = Callee->getName(); - if ((FName.size() == (CalleeName.size() + 1)) && - (FName.back() == 'f') && - FName.startswith(CalleeName)) + // float expf(float val) { return (float) exp((double) val); } + Function *CalleeFn = CI->getCalledFunction(); + StringRef CalleeNm = CalleeFn->getName(); + AttributeList CalleeAt = CalleeFn->getAttributes(); + if (CalleeFn && !CalleeFn->isIntrinsic()) { + const Function *Fn = CI->getFunction(); + StringRef FnName = Fn->getName(); + if (FnName.back() == 'f' && + FnName.size() == (CalleeNm.size() + 1) && + FnName.startswith(CalleeNm)) return nullptr; } - // Propagate fast-math flags from the existing call to the new call. + // Propagate the math semantics from the current function to the new function. IRBuilder<>::FastMathFlagGuard Guard(B); B.setFastMathFlags(CI->getFastMathFlags()); - // floor((double)floatval) -> (double)floorf(floatval) - if (Callee->isIntrinsic()) { + // g((double) float) -> (double) gf(float) + Value *R; + if (CalleeFn->isIntrinsic()) { Module *M = CI->getModule(); - Intrinsic::ID IID = Callee->getIntrinsicID(); - Function *F = Intrinsic::getDeclaration(M, IID, B.getFloatTy()); - V = B.CreateCall(F, V); - } else { - // The call is a library call rather than an intrinsic. - V = emitUnaryFloatFnCall(V, Callee->getName(), B, Callee->getAttributes()); + Intrinsic::ID IID = CalleeFn->getIntrinsicID(); + Function *Fn = Intrinsic::getDeclaration(M, IID, B.getFloatTy()); + R = isBinary ? B.CreateCall(Fn, V) : B.CreateCall(Fn, V[0]); } + else + R = isBinary ? emitBinaryFloatFnCall(V[0], V[1], CalleeNm, B, CalleeAt) + : emitUnaryFloatFnCall(V[0], CalleeNm, B, CalleeAt); - return B.CreateFPExt(V, B.getDoubleTy()); + return B.CreateFPExt(R, B.getDoubleTy()); } -// Replace a libcall \p CI with a call to intrinsic \p IID -static Value *replaceUnaryCall(CallInst *CI, IRBuilder<> &B, Intrinsic::ID IID) { - // Propagate fast-math flags from the existing call to the new call. - IRBuilder<>::FastMathFlagGuard Guard(B); - B.setFastMathFlags(CI->getFastMathFlags()); - - Module *M = CI->getModule(); - Value *V = CI->getArgOperand(0); - Function *F = Intrinsic::getDeclaration(M, IID, CI->getType()); - CallInst *NewCall = B.CreateCall(F, V); - NewCall->takeName(CI); - return NewCall; +/// Shrink double -> float for unary functions. +static Value *optimizeUnaryDoubleFP(CallInst *CI, IRBuilder<> &B, + bool isPrecise = false) { + return optimizeDoubleFP(CI, B, false, isPrecise); } -/// Shrink double -> float for binary functions like 'fmin/fmax'. -static Value *optimizeBinaryDoubleFP(CallInst *CI, IRBuilder<> &B) { - Function *Callee = CI->getCalledFunction(); - // We know this libcall has a valid prototype, but we don't know which. - if (!CI->getType()->isDoubleTy()) - return nullptr; - - // If this is something like 'fmin((double)floatval1, (double)floatval2)', - // or fmin(1.0, (double)floatval), then we convert it to fminf. - Value *V1 = valueHasFloatPrecision(CI->getArgOperand(0)); - if (V1 == nullptr) - return nullptr; - Value *V2 = valueHasFloatPrecision(CI->getArgOperand(1)); - if (V2 == nullptr) - return nullptr; - - // Propagate fast-math flags from the existing call to the new call. - IRBuilder<>::FastMathFlagGuard Guard(B); - B.setFastMathFlags(CI->getFastMathFlags()); - - // fmin((double)floatval1, (double)floatval2) - // -> (double)fminf(floatval1, floatval2) - // TODO: Handle intrinsics in the same way as in optimizeUnaryDoubleFP(). - Value *V = emitBinaryFloatFnCall(V1, V2, Callee->getName(), B, - Callee->getAttributes()); - return B.CreateFPExt(V, B.getDoubleTy()); +/// Shrink double -> float for binary functions. +static Value *optimizeBinaryDoubleFP(CallInst *CI, IRBuilder<> &B, + bool isPrecise = false) { + return optimizeDoubleFP(CI, B, true, isPrecise); } // cabs(z) -> sqrt((creal(z)*creal(z)) + (cimag(z)*cimag(z))) @@ -1078,20 +1128,39 @@ Value *LibCallSimplifier::optimizeCAbs(CallInst *CI, IRBuilder<> &B) { return B.CreateCall(FSqrt, B.CreateFAdd(RealReal, ImagImag), "cabs"); } -Value *LibCallSimplifier::optimizeCos(CallInst *CI, IRBuilder<> &B) { - Function *Callee = CI->getCalledFunction(); - Value *Ret = nullptr; - StringRef Name = Callee->getName(); - if (UnsafeFPShrink && Name == "cos" && hasFloatVersion(Name)) - Ret = optimizeUnaryDoubleFP(CI, B, true); - - // cos(-x) -> cos(x) - Value *Op1 = CI->getArgOperand(0); - if (BinaryOperator::isFNeg(Op1)) { - BinaryOperator *BinExpr = cast<BinaryOperator>(Op1); - return B.CreateCall(Callee, BinExpr->getOperand(1), "cos"); +static Value *optimizeTrigReflections(CallInst *Call, LibFunc Func, + IRBuilder<> &B) { + if (!isa<FPMathOperator>(Call)) + return nullptr; + + IRBuilder<>::FastMathFlagGuard Guard(B); + B.setFastMathFlags(Call->getFastMathFlags()); + + // TODO: Can this be shared to also handle LLVM intrinsics? + Value *X; + switch (Func) { + case LibFunc_sin: + case LibFunc_sinf: + case LibFunc_sinl: + case LibFunc_tan: + case LibFunc_tanf: + case LibFunc_tanl: + // sin(-X) --> -sin(X) + // tan(-X) --> -tan(X) + if (match(Call->getArgOperand(0), m_OneUse(m_FNeg(m_Value(X))))) + return B.CreateFNeg(B.CreateCall(Call->getCalledFunction(), X)); + break; + case LibFunc_cos: + case LibFunc_cosf: + case LibFunc_cosl: + // cos(-X) --> cos(X) + if (match(Call->getArgOperand(0), m_FNeg(m_Value(X)))) + return B.CreateCall(Call->getCalledFunction(), X, "cos"); + break; + default: + break; } - return Ret; + return nullptr; } static Value *getPow(Value *InnerChain[33], unsigned Exp, IRBuilder<> &B) { @@ -1119,37 +1188,175 @@ static Value *getPow(Value *InnerChain[33], unsigned Exp, IRBuilder<> &B) { return InnerChain[Exp]; } -/// Use square root in place of pow(x, +/-0.5). -Value *LibCallSimplifier::replacePowWithSqrt(CallInst *Pow, IRBuilder<> &B) { - // TODO: There is some subset of 'fast' under which these transforms should - // be allowed. - if (!Pow->isFast()) - return nullptr; - - Value *Sqrt, *Base = Pow->getArgOperand(0), *Expo = Pow->getArgOperand(1); +/// Use exp{,2}(x * y) for pow(exp{,2}(x), y); +/// exp2(n * x) for pow(2.0 ** n, x); exp10(x) for pow(10.0, x). +Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilder<> &B) { + Value *Base = Pow->getArgOperand(0), *Expo = Pow->getArgOperand(1); + AttributeList Attrs = Pow->getCalledFunction()->getAttributes(); + Module *Mod = Pow->getModule(); Type *Ty = Pow->getType(); + bool Ignored; - const APFloat *ExpoF; - if (!match(Expo, m_APFloat(ExpoF)) || - (!ExpoF->isExactlyValue(0.5) && !ExpoF->isExactlyValue(-0.5))) + // Evaluate special cases related to a nested function as the base. + + // pow(exp(x), y) -> exp(x * y) + // pow(exp2(x), y) -> exp2(x * y) + // If exp{,2}() is used only once, it is better to fold two transcendental + // math functions into one. If used again, exp{,2}() would still have to be + // called with the original argument, then keep both original transcendental + // functions. However, this transformation is only safe with fully relaxed + // math semantics, since, besides rounding differences, it changes overflow + // and underflow behavior quite dramatically. For example: + // pow(exp(1000), 0.001) = pow(inf, 0.001) = inf + // Whereas: + // exp(1000 * 0.001) = exp(1) + // TODO: Loosen the requirement for fully relaxed math semantics. + // TODO: Handle exp10() when more targets have it available. + CallInst *BaseFn = dyn_cast<CallInst>(Base); + if (BaseFn && BaseFn->hasOneUse() && BaseFn->isFast() && Pow->isFast()) { + LibFunc LibFn; + + Function *CalleeFn = BaseFn->getCalledFunction(); + if (CalleeFn && + TLI->getLibFunc(CalleeFn->getName(), LibFn) && TLI->has(LibFn)) { + StringRef ExpName; + Intrinsic::ID ID; + Value *ExpFn; + LibFunc LibFnFloat; + LibFunc LibFnDouble; + LibFunc LibFnLongDouble; + + switch (LibFn) { + default: + return nullptr; + case LibFunc_expf: case LibFunc_exp: case LibFunc_expl: + ExpName = TLI->getName(LibFunc_exp); + ID = Intrinsic::exp; + LibFnFloat = LibFunc_expf; + LibFnDouble = LibFunc_exp; + LibFnLongDouble = LibFunc_expl; + break; + case LibFunc_exp2f: case LibFunc_exp2: case LibFunc_exp2l: + ExpName = TLI->getName(LibFunc_exp2); + ID = Intrinsic::exp2; + LibFnFloat = LibFunc_exp2f; + LibFnDouble = LibFunc_exp2; + LibFnLongDouble = LibFunc_exp2l; + break; + } + + // Create new exp{,2}() with the product as its argument. + Value *FMul = B.CreateFMul(BaseFn->getArgOperand(0), Expo, "mul"); + ExpFn = BaseFn->doesNotAccessMemory() + ? B.CreateCall(Intrinsic::getDeclaration(Mod, ID, Ty), + FMul, ExpName) + : emitUnaryFloatFnCall(FMul, TLI, LibFnDouble, LibFnFloat, + LibFnLongDouble, B, + BaseFn->getAttributes()); + + // Since the new exp{,2}() is different from the original one, dead code + // elimination cannot be trusted to remove it, since it may have side + // effects (e.g., errno). When the only consumer for the original + // exp{,2}() is pow(), then it has to be explicitly erased. + BaseFn->replaceAllUsesWith(ExpFn); + eraseFromParent(BaseFn); + + return ExpFn; + } + } + + // Evaluate special cases related to a constant base. + + const APFloat *BaseF; + if (!match(Pow->getArgOperand(0), m_APFloat(BaseF))) return nullptr; + // pow(2.0 ** n, x) -> exp2(n * x) + if (hasUnaryFloatFn(TLI, Ty, LibFunc_exp2, LibFunc_exp2f, LibFunc_exp2l)) { + APFloat BaseR = APFloat(1.0); + BaseR.convert(BaseF->getSemantics(), APFloat::rmTowardZero, &Ignored); + BaseR = BaseR / *BaseF; + bool IsInteger = BaseF->isInteger(), + IsReciprocal = BaseR.isInteger(); + const APFloat *NF = IsReciprocal ? &BaseR : BaseF; + APSInt NI(64, false); + if ((IsInteger || IsReciprocal) && + !NF->convertToInteger(NI, APFloat::rmTowardZero, &Ignored) && + NI > 1 && NI.isPowerOf2()) { + double N = NI.logBase2() * (IsReciprocal ? -1.0 : 1.0); + Value *FMul = B.CreateFMul(Expo, ConstantFP::get(Ty, N), "mul"); + if (Pow->doesNotAccessMemory()) + return B.CreateCall(Intrinsic::getDeclaration(Mod, Intrinsic::exp2, Ty), + FMul, "exp2"); + else + return emitUnaryFloatFnCall(FMul, TLI, LibFunc_exp2, LibFunc_exp2f, + LibFunc_exp2l, B, Attrs); + } + } + + // pow(10.0, x) -> exp10(x) + // TODO: There is no exp10() intrinsic yet, but some day there shall be one. + if (match(Base, m_SpecificFP(10.0)) && + hasUnaryFloatFn(TLI, Ty, LibFunc_exp10, LibFunc_exp10f, LibFunc_exp10l)) + return emitUnaryFloatFnCall(Expo, TLI, LibFunc_exp10, LibFunc_exp10f, + LibFunc_exp10l, B, Attrs); + + return nullptr; +} + +static Value *getSqrtCall(Value *V, AttributeList Attrs, bool NoErrno, + Module *M, IRBuilder<> &B, + const TargetLibraryInfo *TLI) { // If errno is never set, then use the intrinsic for sqrt(). - if (Pow->hasFnAttr(Attribute::ReadNone)) { - Function *SqrtFn = Intrinsic::getDeclaration(Pow->getModule(), - Intrinsic::sqrt, Ty); - Sqrt = B.CreateCall(SqrtFn, Base); + if (NoErrno) { + Function *SqrtFn = + Intrinsic::getDeclaration(M, Intrinsic::sqrt, V->getType()); + return B.CreateCall(SqrtFn, V, "sqrt"); } + // Otherwise, use the libcall for sqrt(). - else if (hasUnaryFloatFn(TLI, Ty, LibFunc_sqrt, LibFunc_sqrtf, LibFunc_sqrtl)) + if (hasUnaryFloatFn(TLI, V->getType(), LibFunc_sqrt, LibFunc_sqrtf, + LibFunc_sqrtl)) // TODO: We also should check that the target can in fact lower the sqrt() // libcall. We currently have no way to ask this question, so we ask if // the target has a sqrt() libcall, which is not exactly the same. - Sqrt = emitUnaryFloatFnCall(Base, TLI->getName(LibFunc_sqrt), B, - Pow->getCalledFunction()->getAttributes()); - else + return emitUnaryFloatFnCall(V, TLI, LibFunc_sqrt, LibFunc_sqrtf, + LibFunc_sqrtl, B, Attrs); + + return nullptr; +} + +/// Use square root in place of pow(x, +/-0.5). +Value *LibCallSimplifier::replacePowWithSqrt(CallInst *Pow, IRBuilder<> &B) { + Value *Sqrt, *Base = Pow->getArgOperand(0), *Expo = Pow->getArgOperand(1); + AttributeList Attrs = Pow->getCalledFunction()->getAttributes(); + Module *Mod = Pow->getModule(); + Type *Ty = Pow->getType(); + + const APFloat *ExpoF; + if (!match(Expo, m_APFloat(ExpoF)) || + (!ExpoF->isExactlyValue(0.5) && !ExpoF->isExactlyValue(-0.5))) return nullptr; + Sqrt = getSqrtCall(Base, Attrs, Pow->doesNotAccessMemory(), Mod, B, TLI); + if (!Sqrt) + return nullptr; + + // Handle signed zero base by expanding to fabs(sqrt(x)). + if (!Pow->hasNoSignedZeros()) { + Function *FAbsFn = Intrinsic::getDeclaration(Mod, Intrinsic::fabs, Ty); + Sqrt = B.CreateCall(FAbsFn, Sqrt, "abs"); + } + + // Handle non finite base by expanding to + // (x == -infinity ? +infinity : sqrt(x)). + if (!Pow->hasNoInfs()) { + Value *PosInf = ConstantFP::getInfinity(Ty), + *NegInf = ConstantFP::getInfinity(Ty, true); + Value *FCmp = B.CreateFCmpOEQ(Base, NegInf, "isinf"); + Sqrt = B.CreateSelect(FCmp, PosInf, Sqrt); + } + // If the exponent is negative, then get the reciprocal. if (ExpoF->isNegative()) Sqrt = B.CreateFDiv(ConstantFP::get(Ty, 1.0), Sqrt, "reciprocal"); @@ -1160,134 +1367,109 @@ Value *LibCallSimplifier::replacePowWithSqrt(CallInst *Pow, IRBuilder<> &B) { Value *LibCallSimplifier::optimizePow(CallInst *Pow, IRBuilder<> &B) { Value *Base = Pow->getArgOperand(0), *Expo = Pow->getArgOperand(1); Function *Callee = Pow->getCalledFunction(); - AttributeList Attrs = Callee->getAttributes(); StringRef Name = Callee->getName(); - Module *Module = Pow->getModule(); Type *Ty = Pow->getType(); Value *Shrunk = nullptr; bool Ignored; - if (UnsafeFPShrink && - Name == TLI->getName(LibFunc_pow) && hasFloatVersion(Name)) - Shrunk = optimizeUnaryDoubleFP(Pow, B, true); + // Bail out if simplifying libcalls to pow() is disabled. + if (!hasUnaryFloatFn(TLI, Ty, LibFunc_pow, LibFunc_powf, LibFunc_powl)) + return nullptr; // Propagate the math semantics from the call to any created instructions. IRBuilder<>::FastMathFlagGuard Guard(B); B.setFastMathFlags(Pow->getFastMathFlags()); + // Shrink pow() to powf() if the arguments are single precision, + // unless the result is expected to be double precision. + if (UnsafeFPShrink && + Name == TLI->getName(LibFunc_pow) && hasFloatVersion(Name)) + Shrunk = optimizeBinaryDoubleFP(Pow, B, true); + // Evaluate special cases related to the base. // pow(1.0, x) -> 1.0 - if (match(Base, m_SpecificFP(1.0))) + if (match(Base, m_FPOne())) return Base; - // pow(2.0, x) -> exp2(x) - if (match(Base, m_SpecificFP(2.0))) { - Value *Exp2 = Intrinsic::getDeclaration(Module, Intrinsic::exp2, Ty); - return B.CreateCall(Exp2, Expo, "exp2"); - } - - // pow(10.0, x) -> exp10(x) - if (ConstantFP *BaseC = dyn_cast<ConstantFP>(Base)) - // There's no exp10 intrinsic yet, but, maybe, some day there shall be one. - if (BaseC->isExactlyValue(10.0) && - hasUnaryFloatFn(TLI, Ty, LibFunc_exp10, LibFunc_exp10f, LibFunc_exp10l)) - return emitUnaryFloatFnCall(Expo, TLI->getName(LibFunc_exp10), B, Attrs); - - // pow(exp(x), y) -> exp(x * y) - // pow(exp2(x), y) -> exp2(x * y) - // We enable these only with fast-math. Besides rounding differences, the - // transformation changes overflow and underflow behavior quite dramatically. - // Example: x = 1000, y = 0.001. - // pow(exp(x), y) = pow(inf, 0.001) = inf, whereas exp(x*y) = exp(1). - auto *BaseFn = dyn_cast<CallInst>(Base); - if (BaseFn && BaseFn->isFast() && Pow->isFast()) { - LibFunc LibFn; - Function *CalleeFn = BaseFn->getCalledFunction(); - if (CalleeFn && TLI->getLibFunc(CalleeFn->getName(), LibFn) && - (LibFn == LibFunc_exp || LibFn == LibFunc_exp2) && TLI->has(LibFn)) { - IRBuilder<>::FastMathFlagGuard Guard(B); - B.setFastMathFlags(Pow->getFastMathFlags()); - - Value *FMul = B.CreateFMul(BaseFn->getArgOperand(0), Expo, "mul"); - return emitUnaryFloatFnCall(FMul, CalleeFn->getName(), B, - CalleeFn->getAttributes()); - } - } + if (Value *Exp = replacePowWithExp(Pow, B)) + return Exp; // Evaluate special cases related to the exponent. - if (Value *Sqrt = replacePowWithSqrt(Pow, B)) - return Sqrt; - - ConstantFP *ExpoC = dyn_cast<ConstantFP>(Expo); - if (!ExpoC) - return Shrunk; - // pow(x, -1.0) -> 1.0 / x - if (ExpoC->isExactlyValue(-1.0)) + if (match(Expo, m_SpecificFP(-1.0))) return B.CreateFDiv(ConstantFP::get(Ty, 1.0), Base, "reciprocal"); // pow(x, 0.0) -> 1.0 - if (ExpoC->getValueAPF().isZero()) - return ConstantFP::get(Ty, 1.0); + if (match(Expo, m_SpecificFP(0.0))) + return ConstantFP::get(Ty, 1.0); // pow(x, 1.0) -> x - if (ExpoC->isExactlyValue(1.0)) + if (match(Expo, m_FPOne())) return Base; // pow(x, 2.0) -> x * x - if (ExpoC->isExactlyValue(2.0)) + if (match(Expo, m_SpecificFP(2.0))) return B.CreateFMul(Base, Base, "square"); - // FIXME: Correct the transforms and pull this into replacePowWithSqrt(). - if (ExpoC->isExactlyValue(0.5) && - hasUnaryFloatFn(TLI, Ty, LibFunc_sqrt, LibFunc_sqrtf, LibFunc_sqrtl)) { - // Expand pow(x, 0.5) to (x == -infinity ? +infinity : fabs(sqrt(x))). - // This is faster than calling pow(), and still handles -0.0 and - // negative infinity correctly. - // TODO: In finite-only mode, this could be just fabs(sqrt(x)). - Value *PosInf = ConstantFP::getInfinity(Ty); - Value *NegInf = ConstantFP::getInfinity(Ty, true); - - // TODO: As above, we should lower to the sqrt() intrinsic if the pow() is - // an intrinsic, to match errno semantics. - Value *Sqrt = emitUnaryFloatFnCall(Base, TLI->getName(LibFunc_sqrt), - B, Attrs); - Function *FAbsFn = Intrinsic::getDeclaration(Module, Intrinsic::fabs, Ty); - Value *FAbs = B.CreateCall(FAbsFn, Sqrt, "abs"); - Value *FCmp = B.CreateFCmpOEQ(Base, NegInf, "isinf"); - Sqrt = B.CreateSelect(FCmp, PosInf, FAbs); + if (Value *Sqrt = replacePowWithSqrt(Pow, B)) return Sqrt; - } - // pow(x, n) -> x * x * x * .... - if (Pow->isFast()) { - APFloat ExpoA = abs(ExpoC->getValueAPF()); - // We limit to a max of 7 fmul(s). Thus the maximum exponent is 32. - // This transformation applies to integer exponents only. - if (!ExpoA.isInteger() || - ExpoA.compare - (APFloat(ExpoA.getSemantics(), 32.0)) == APFloat::cmpGreaterThan) - return nullptr; + // pow(x, n) -> x * x * x * ... + const APFloat *ExpoF; + if (Pow->isFast() && match(Expo, m_APFloat(ExpoF))) { + // We limit to a max of 7 multiplications, thus the maximum exponent is 32. + // If the exponent is an integer+0.5 we generate a call to sqrt and an + // additional fmul. + // TODO: This whole transformation should be backend specific (e.g. some + // backends might prefer libcalls or the limit for the exponent might + // be different) and it should also consider optimizing for size. + APFloat LimF(ExpoF->getSemantics(), 33.0), + ExpoA(abs(*ExpoF)); + if (ExpoA.compare(LimF) == APFloat::cmpLessThan) { + // This transformation applies to integer or integer+0.5 exponents only. + // For integer+0.5, we create a sqrt(Base) call. + Value *Sqrt = nullptr; + if (!ExpoA.isInteger()) { + APFloat Expo2 = ExpoA; + // To check if ExpoA is an integer + 0.5, we add it to itself. If there + // is no floating point exception and the result is an integer, then + // ExpoA == integer + 0.5 + if (Expo2.add(ExpoA, APFloat::rmNearestTiesToEven) != APFloat::opOK) + return nullptr; + + if (!Expo2.isInteger()) + return nullptr; + + Sqrt = + getSqrtCall(Base, Pow->getCalledFunction()->getAttributes(), + Pow->doesNotAccessMemory(), Pow->getModule(), B, TLI); + } - // We will memoize intermediate products of the Addition Chain. - Value *InnerChain[33] = {nullptr}; - InnerChain[1] = Base; - InnerChain[2] = B.CreateFMul(Base, Base, "square"); + // We will memoize intermediate products of the Addition Chain. + Value *InnerChain[33] = {nullptr}; + InnerChain[1] = Base; + InnerChain[2] = B.CreateFMul(Base, Base, "square"); - // We cannot readily convert a non-double type (like float) to a double. - // So we first convert it to something which could be converted to double. - ExpoA.convert(APFloat::IEEEdouble(), APFloat::rmTowardZero, &Ignored); - Value *FMul = getPow(InnerChain, ExpoA.convertToDouble(), B); + // We cannot readily convert a non-double type (like float) to a double. + // So we first convert it to something which could be converted to double. + ExpoA.convert(APFloat::IEEEdouble(), APFloat::rmTowardZero, &Ignored); + Value *FMul = getPow(InnerChain, ExpoA.convertToDouble(), B); - // If the exponent is negative, then get the reciprocal. - if (ExpoC->isNegative()) - FMul = B.CreateFDiv(ConstantFP::get(Ty, 1.0), FMul, "reciprocal"); - return FMul; + // Expand pow(x, y+0.5) to pow(x, y) * sqrt(x). + if (Sqrt) + FMul = B.CreateFMul(FMul, Sqrt); + + // If the exponent is negative, then get the reciprocal. + if (ExpoF->isNegative()) + FMul = B.CreateFDiv(ConstantFP::get(Ty, 1.0), FMul, "reciprocal"); + + return FMul; + } } - return nullptr; + return Shrunk; } Value *LibCallSimplifier::optimizeExp2(CallInst *CI, IRBuilder<> &B) { @@ -2285,11 +2467,10 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI, if (CI->isStrictFP()) return nullptr; + if (Value *V = optimizeTrigReflections(CI, Func, Builder)) + return V; + switch (Func) { - case LibFunc_cosf: - case LibFunc_cos: - case LibFunc_cosl: - return optimizeCos(CI, Builder); case LibFunc_sinpif: case LibFunc_sinpi: case LibFunc_cospif: @@ -2344,6 +2525,7 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI, case LibFunc_exp: case LibFunc_exp10: case LibFunc_expm1: + case LibFunc_cos: case LibFunc_sin: case LibFunc_sinh: case LibFunc_tanh: @@ -2425,7 +2607,7 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) { if (Value *V = optimizeStringMemoryLibCall(SimplifiedCI, TmpBuilder)) { // If we were able to further simplify, remove the now redundant call. SimplifiedCI->replaceAllUsesWith(V); - SimplifiedCI->eraseFromParent(); + eraseFromParent(SimplifiedCI); return V; } } @@ -2504,15 +2686,20 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) { LibCallSimplifier::LibCallSimplifier( const DataLayout &DL, const TargetLibraryInfo *TLI, OptimizationRemarkEmitter &ORE, - function_ref<void(Instruction *, Value *)> Replacer) + function_ref<void(Instruction *, Value *)> Replacer, + function_ref<void(Instruction *)> Eraser) : FortifiedSimplifier(TLI), DL(DL), TLI(TLI), ORE(ORE), - UnsafeFPShrink(false), Replacer(Replacer) {} + UnsafeFPShrink(false), Replacer(Replacer), Eraser(Eraser) {} void LibCallSimplifier::replaceAllUsesWith(Instruction *I, Value *With) { // Indirect through the replacer used in this instance. Replacer(I, With); } +void LibCallSimplifier::eraseFromParent(Instruction *I) { + Eraser(I); +} + // TODO: // Additional cases that we need to add to this file: // |