diff options
Diffstat (limited to 'llvm/lib/CodeGen/ExpandMemCmp.cpp')
| -rw-r--r-- | llvm/lib/CodeGen/ExpandMemCmp.cpp | 148 |
1 files changed, 121 insertions, 27 deletions
diff --git a/llvm/lib/CodeGen/ExpandMemCmp.cpp b/llvm/lib/CodeGen/ExpandMemCmp.cpp index 500f31bd8e89..e6ca14096249 100644 --- a/llvm/lib/CodeGen/ExpandMemCmp.cpp +++ b/llvm/lib/CodeGen/ExpandMemCmp.cpp @@ -23,6 +23,7 @@ #include "llvm/CodeGen/TargetSubtargetInfo.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/InitializePasses.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" @@ -31,6 +32,7 @@ #include <optional> using namespace llvm; +using namespace llvm::PatternMatch; namespace llvm { class TargetLowering; @@ -117,8 +119,8 @@ class MemCmpExpansion { Value *Lhs = nullptr; Value *Rhs = nullptr; }; - LoadPair getLoadPair(Type *LoadSizeType, bool NeedsBSwap, Type *CmpSizeType, - unsigned OffsetBytes); + LoadPair getLoadPair(Type *LoadSizeType, Type *BSwapSizeType, + Type *CmpSizeType, unsigned OffsetBytes); static LoadEntryVector computeGreedyLoadSequence(uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes, @@ -128,6 +130,11 @@ class MemCmpExpansion { unsigned MaxNumLoads, unsigned &NumLoadsNonOneByte); + static void optimiseLoadSequence( + LoadEntryVector &LoadSequence, + const TargetTransformInfo::MemCmpExpansionOptions &Options, + bool IsUsedForZeroCmp); + public: MemCmpExpansion(CallInst *CI, uint64_t Size, const TargetTransformInfo::MemCmpExpansionOptions &Options, @@ -210,6 +217,37 @@ MemCmpExpansion::computeOverlappingLoadSequence(uint64_t Size, return LoadSequence; } +void MemCmpExpansion::optimiseLoadSequence( + LoadEntryVector &LoadSequence, + const TargetTransformInfo::MemCmpExpansionOptions &Options, + bool IsUsedForZeroCmp) { + // This part of code attempts to optimize the LoadSequence by merging allowed + // subsequences into single loads of allowed sizes from + // `MemCmpExpansionOptions::AllowedTailExpansions`. If it is for zero + // comparison or if no allowed tail expansions are specified, we exit early. + if (IsUsedForZeroCmp || Options.AllowedTailExpansions.empty()) + return; + + while (LoadSequence.size() >= 2) { + auto Last = LoadSequence[LoadSequence.size() - 1]; + auto PreLast = LoadSequence[LoadSequence.size() - 2]; + + // Exit the loop if the two sequences are not contiguous + if (PreLast.Offset + PreLast.LoadSize != Last.Offset) + break; + + auto LoadSize = Last.LoadSize + PreLast.LoadSize; + if (find(Options.AllowedTailExpansions, LoadSize) == + Options.AllowedTailExpansions.end()) + break; + + // Remove the last two sequences and replace with the combined sequence + LoadSequence.pop_back(); + LoadSequence.pop_back(); + LoadSequence.emplace_back(PreLast.Offset, LoadSize); + } +} + // Initialize the basic block structure required for expansion of memcmp call // with given maximum load size and memcmp size parameter. // This structure includes: @@ -255,6 +293,7 @@ MemCmpExpansion::MemCmpExpansion( } } assert(LoadSequence.size() <= Options.MaxNumLoads && "broken invariant"); + optimiseLoadSequence(LoadSequence, Options, IsUsedForZeroCmp); } unsigned MemCmpExpansion::getNumBlocks() { @@ -278,7 +317,7 @@ void MemCmpExpansion::createResultBlock() { } MemCmpExpansion::LoadPair MemCmpExpansion::getLoadPair(Type *LoadSizeType, - bool NeedsBSwap, + Type *BSwapSizeType, Type *CmpSizeType, unsigned OffsetBytes) { // Get the memory source at offset `OffsetBytes`. @@ -307,16 +346,22 @@ MemCmpExpansion::LoadPair MemCmpExpansion::getLoadPair(Type *LoadSizeType, if (!Rhs) Rhs = Builder.CreateAlignedLoad(LoadSizeType, RhsSource, RhsAlign); + // Zero extend if Byte Swap intrinsic has different type + if (BSwapSizeType && LoadSizeType != BSwapSizeType) { + Lhs = Builder.CreateZExt(Lhs, BSwapSizeType); + Rhs = Builder.CreateZExt(Rhs, BSwapSizeType); + } + // Swap bytes if required. - if (NeedsBSwap) { - Function *Bswap = Intrinsic::getDeclaration(CI->getModule(), - Intrinsic::bswap, LoadSizeType); + if (BSwapSizeType) { + Function *Bswap = Intrinsic::getDeclaration( + CI->getModule(), Intrinsic::bswap, BSwapSizeType); Lhs = Builder.CreateCall(Bswap, Lhs); Rhs = Builder.CreateCall(Bswap, Rhs); } // Zero extend if required. - if (CmpSizeType != nullptr && CmpSizeType != LoadSizeType) { + if (CmpSizeType != nullptr && CmpSizeType != Lhs->getType()) { Lhs = Builder.CreateZExt(Lhs, CmpSizeType); Rhs = Builder.CreateZExt(Rhs, CmpSizeType); } @@ -332,7 +377,7 @@ void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex, BasicBlock *BB = LoadCmpBlocks[BlockIndex]; Builder.SetInsertPoint(BB); const LoadPair Loads = - getLoadPair(Type::getInt8Ty(CI->getContext()), /*NeedsBSwap=*/false, + getLoadPair(Type::getInt8Ty(CI->getContext()), nullptr, Type::getInt32Ty(CI->getContext()), OffsetBytes); Value *Diff = Builder.CreateSub(Loads.Lhs, Loads.Rhs); @@ -385,11 +430,12 @@ Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex, IntegerType *const MaxLoadType = NumLoads == 1 ? nullptr : IntegerType::get(CI->getContext(), MaxLoadSize * 8); + for (unsigned i = 0; i < NumLoads; ++i, ++LoadIndex) { const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex]; const LoadPair Loads = getLoadPair( - IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8), - /*NeedsBSwap=*/false, MaxLoadType, CurLoadEntry.Offset); + IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8), nullptr, + MaxLoadType, CurLoadEntry.Offset); if (NumLoads != 1) { // If we have multiple loads per block, we need to generate a composite @@ -475,14 +521,20 @@ void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) { Type *LoadSizeType = IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8); - Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8); + Type *BSwapSizeType = + DL.isLittleEndian() + ? IntegerType::get(CI->getContext(), + PowerOf2Ceil(CurLoadEntry.LoadSize * 8)) + : nullptr; + Type *MaxLoadType = IntegerType::get( + CI->getContext(), + std::max(MaxLoadSize, (unsigned)PowerOf2Ceil(CurLoadEntry.LoadSize)) * 8); assert(CurLoadEntry.LoadSize <= MaxLoadSize && "Unexpected load type"); Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]); - const LoadPair Loads = - getLoadPair(LoadSizeType, /*NeedsBSwap=*/DL.isLittleEndian(), MaxLoadType, - CurLoadEntry.Offset); + const LoadPair Loads = getLoadPair(LoadSizeType, BSwapSizeType, MaxLoadType, + CurLoadEntry.Offset); // Add the loaded values to the phi nodes for calculating memcmp result only // if result is not used in a zero equality. @@ -558,7 +610,7 @@ void MemCmpExpansion::setupResultBlockPHINodes() { } void MemCmpExpansion::setupEndBlockPHINodes() { - Builder.SetInsertPoint(&EndBlock->front()); + Builder.SetInsertPoint(EndBlock, EndBlock->begin()); PhiRes = Builder.CreatePHI(Type::getInt32Ty(CI->getContext()), 2, "phi.res"); } @@ -586,21 +638,63 @@ Value *MemCmpExpansion::getMemCmpEqZeroOneBlock() { /// A memcmp expansion that only has one block of load and compare can bypass /// the compare, branch, and phi IR that is required in the general case. +/// This function also analyses users of memcmp, and if there is only one user +/// from which we can conclude that only 2 out of 3 memcmp outcomes really +/// matter, then it generates more efficient code with only one comparison. Value *MemCmpExpansion::getMemCmpOneBlock() { - Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8); bool NeedsBSwap = DL.isLittleEndian() && Size != 1; + Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8); + Type *BSwapSizeType = + NeedsBSwap ? IntegerType::get(CI->getContext(), PowerOf2Ceil(Size * 8)) + : nullptr; + Type *MaxLoadType = + IntegerType::get(CI->getContext(), + std::max(MaxLoadSize, (unsigned)PowerOf2Ceil(Size)) * 8); // The i8 and i16 cases don't need compares. We zext the loaded values and // subtract them to get the suitable negative, zero, or positive i32 result. - if (Size < 4) { - const LoadPair Loads = - getLoadPair(LoadSizeType, NeedsBSwap, Builder.getInt32Ty(), - /*Offset*/ 0); + if (Size == 1 || Size == 2) { + const LoadPair Loads = getLoadPair(LoadSizeType, BSwapSizeType, + Builder.getInt32Ty(), /*Offset*/ 0); return Builder.CreateSub(Loads.Lhs, Loads.Rhs); } - const LoadPair Loads = getLoadPair(LoadSizeType, NeedsBSwap, LoadSizeType, + const LoadPair Loads = getLoadPair(LoadSizeType, BSwapSizeType, MaxLoadType, /*Offset*/ 0); + + // If a user of memcmp cares only about two outcomes, for example: + // bool result = memcmp(a, b, NBYTES) > 0; + // We can generate more optimal code with a smaller number of operations + if (CI->hasOneUser()) { + auto *UI = cast<Instruction>(*CI->user_begin()); + ICmpInst::Predicate Pred = ICmpInst::Predicate::BAD_ICMP_PREDICATE; + uint64_t Shift; + bool NeedsZExt = false; + // This is a special case because instead of checking if the result is less + // than zero: + // bool result = memcmp(a, b, NBYTES) < 0; + // Compiler is clever enough to generate the following code: + // bool result = memcmp(a, b, NBYTES) >> 31; + if (match(UI, m_LShr(m_Value(), m_ConstantInt(Shift))) && + Shift == (CI->getType()->getIntegerBitWidth() - 1)) { + Pred = ICmpInst::ICMP_SLT; + NeedsZExt = true; + } else { + // In case of a successful match this call will set `Pred` variable + match(UI, m_ICmp(Pred, m_Specific(CI), m_Zero())); + } + // Generate new code and remove the original memcmp call and the user + if (ICmpInst::isSigned(Pred)) { + Value *Cmp = Builder.CreateICmp(CmpInst::getUnsignedPredicate(Pred), + Loads.Lhs, Loads.Rhs); + auto *Result = NeedsZExt ? Builder.CreateZExt(Cmp, UI->getType()) : Cmp; + UI->replaceAllUsesWith(Result); + UI->eraseFromParent(); + CI->eraseFromParent(); + return nullptr; + } + } + // The result of memcmp is negative, zero, or positive, so produce that by // subtracting 2 extended compare bits: sub (ugt, ult). // If a target prefers to use selects to get -1/0/1, they should be able @@ -615,7 +709,7 @@ Value *MemCmpExpansion::getMemCmpOneBlock() { } // This function expands the memcmp call into an inline expansion and returns -// the memcmp result. +// the memcmp result. Returns nullptr if the memcmp is already replaced. Value *MemCmpExpansion::getMemCmpExpansion() { // Create the basic block framework for a multi-block expansion. if (getNumBlocks() != 1) { @@ -783,11 +877,11 @@ static bool expandMemCmp(CallInst *CI, const TargetTransformInfo *TTI, NumMemCmpInlined++; - Value *Res = Expansion.getMemCmpExpansion(); - - // Replace call with result of expansion and erase call. - CI->replaceAllUsesWith(Res); - CI->eraseFromParent(); + if (Value *Res = Expansion.getMemCmpExpansion()) { + // Replace call with result of expansion and erase call. + CI->replaceAllUsesWith(Res); + CI->eraseFromParent(); + } return true; } |
