diff options
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp')
-rw-r--r-- | contrib/llvm-project/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp | 55 |
1 files changed, 51 insertions, 4 deletions
diff --git a/contrib/llvm-project/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/contrib/llvm-project/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp index 1fd8b88dd776..35adaa3bde65 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp @@ -31,6 +31,7 @@ #include "llvm/IR/PatternMatch.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" +#include "llvm/Transforms/Utils/BuildLibCalls.h" #include "llvm/Transforms/Utils/Local.h" using namespace llvm; @@ -427,27 +428,73 @@ static bool tryToFPToSat(Instruction &I, TargetTransformInfo &TTI) { return true; } +/// Try to replace a mathlib call to sqrt with the LLVM intrinsic. This avoids +/// pessimistic codegen that has to account for setting errno and can enable +/// vectorization. +static bool +foldSqrt(Instruction &I, TargetTransformInfo &TTI, TargetLibraryInfo &TLI) { + // Match a call to sqrt mathlib function. + auto *Call = dyn_cast<CallInst>(&I); + if (!Call) + return false; + + Module *M = Call->getModule(); + LibFunc Func; + if (!TLI.getLibFunc(*Call, Func) || !isLibFuncEmittable(M, &TLI, Func)) + return false; + + if (Func != LibFunc_sqrt && Func != LibFunc_sqrtf && Func != LibFunc_sqrtl) + return false; + + // If (1) this is a sqrt libcall, (2) we can assume that NAN is not created, + // and (3) we would not end up lowering to a libcall anyway (which could + // change the value of errno), then: + // (1) the operand arg must not be less than -0.0. + // (2) errno won't be set. + // (3) it is safe to convert this to an intrinsic call. + // TODO: Check if the arg is known non-negative. + Type *Ty = Call->getType(); + if (TTI.haveFastSqrt(Ty) && Call->hasNoNaNs()) { + IRBuilder<> Builder(&I); + IRBuilderBase::FastMathFlagGuard Guard(Builder); + Builder.setFastMathFlags(Call->getFastMathFlags()); + + Function *Sqrt = Intrinsic::getDeclaration(M, Intrinsic::sqrt, Ty); + Value *NewSqrt = Builder.CreateCall(Sqrt, Call->getArgOperand(0), "sqrt"); + I.replaceAllUsesWith(NewSqrt); + + // Explicitly erase the old call because a call with side effects is not + // trivially dead. + I.eraseFromParent(); + return true; + } + + return false; +} + /// This is the entry point for folds that could be implemented in regular /// InstCombine, but they are separated because they are not expected to /// occur frequently and/or have more than a constant-length pattern match. static bool foldUnusualPatterns(Function &F, DominatorTree &DT, - TargetTransformInfo &TTI) { + TargetTransformInfo &TTI, + TargetLibraryInfo &TLI) { bool MadeChange = false; for (BasicBlock &BB : F) { // Ignore unreachable basic blocks. if (!DT.isReachableFromEntry(&BB)) continue; - // Do not delete instructions under here and invalidate the iterator. + // Walk the block backwards for efficiency. We're matching a chain of // use->defs, so we're more likely to succeed by starting from the bottom. // Also, we want to avoid matching partial patterns. // TODO: It would be more efficient if we removed dead instructions // iteratively in this loop rather than waiting until the end. - for (Instruction &I : llvm::reverse(BB)) { + for (Instruction &I : make_early_inc_range(llvm::reverse(BB))) { MadeChange |= foldAnyOrAllBitsSet(I); MadeChange |= foldGuardedFunnelShift(I, DT); MadeChange |= tryToRecognizePopCount(I); MadeChange |= tryToFPToSat(I, TTI); + MadeChange |= foldSqrt(I, TTI, TLI); } } @@ -467,7 +514,7 @@ static bool runImpl(Function &F, AssumptionCache &AC, TargetTransformInfo &TTI, const DataLayout &DL = F.getParent()->getDataLayout(); TruncInstCombine TIC(AC, TLI, DL, DT); MadeChange |= TIC.run(F); - MadeChange |= foldUnusualPatterns(F, DT, TTI); + MadeChange |= foldUnusualPatterns(F, DT, TTI, TLI); return MadeChange; } |