diff options
Diffstat (limited to 'llvm/lib/CodeGen/ExpandMemCmp.cpp')
-rw-r--r-- | llvm/lib/CodeGen/ExpandMemCmp.cpp | 185 |
1 file changed, 85 insertions, 100 deletions
diff --git a/llvm/lib/CodeGen/ExpandMemCmp.cpp b/llvm/lib/CodeGen/ExpandMemCmp.cpp index a1adf4ef9820c..9f85db9de8848 100644 --- a/llvm/lib/CodeGen/ExpandMemCmp.cpp +++ b/llvm/lib/CodeGen/ExpandMemCmp.cpp @@ -23,7 +23,9 @@ #include "llvm/CodeGen/TargetSubtargetInfo.h" #include "llvm/IR/IRBuilder.h" #include "llvm/InitializePasses.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/SizeOpts.h" +#include "llvm/Target/TargetMachine.h" using namespace llvm; @@ -76,7 +78,7 @@ class MemCmpExpansion { IRBuilder<> Builder; // Represents the decomposition in blocks of the expansion. For example, // comparing 33 bytes on X86+sse can be done with 2x16-byte loads and - // 1x1-byte load, which would be represented as [{16, 0}, {16, 16}, {32, 1}. + // 1x1-byte load, which would be represented as [{16, 0}, {16, 16}, {1, 32}. struct LoadEntry { LoadEntry(unsigned LoadSize, uint64_t Offset) : LoadSize(LoadSize), Offset(Offset) { @@ -103,8 +105,12 @@ class MemCmpExpansion { Value *getMemCmpExpansionZeroCase(); Value *getMemCmpEqZeroOneBlock(); Value *getMemCmpOneBlock(); - Value *getPtrToElementAtOffset(Value *Source, Type *LoadSizeType, - uint64_t OffsetBytes); + struct LoadPair { + Value *Lhs = nullptr; + Value *Rhs = nullptr; + }; + LoadPair getLoadPair(Type *LoadSizeType, bool NeedsBSwap, Type *CmpSizeType, + unsigned OffsetBytes); static LoadEntryVector computeGreedyLoadSequence(uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes, @@ -261,18 +267,56 @@ void MemCmpExpansion::createResultBlock() { EndBlock->getParent(), EndBlock); } -/// Return a pointer to an element of type `LoadSizeType` at offset -/// `OffsetBytes`. -Value *MemCmpExpansion::getPtrToElementAtOffset(Value *Source, - Type *LoadSizeType, - uint64_t OffsetBytes) { +MemCmpExpansion::LoadPair MemCmpExpansion::getLoadPair(Type *LoadSizeType, + bool NeedsBSwap, + Type *CmpSizeType, + unsigned OffsetBytes) { + // Get the memory source at offset `OffsetBytes`. 
+ Value *LhsSource = CI->getArgOperand(0); + Value *RhsSource = CI->getArgOperand(1); + Align LhsAlign = LhsSource->getPointerAlignment(DL); + Align RhsAlign = RhsSource->getPointerAlignment(DL); if (OffsetBytes > 0) { auto *ByteType = Type::getInt8Ty(CI->getContext()); - Source = Builder.CreateConstGEP1_64( - ByteType, Builder.CreateBitCast(Source, ByteType->getPointerTo()), + LhsSource = Builder.CreateConstGEP1_64( + ByteType, Builder.CreateBitCast(LhsSource, ByteType->getPointerTo()), + OffsetBytes); + RhsSource = Builder.CreateConstGEP1_64( + ByteType, Builder.CreateBitCast(RhsSource, ByteType->getPointerTo()), OffsetBytes); + LhsAlign = commonAlignment(LhsAlign, OffsetBytes); + RhsAlign = commonAlignment(RhsAlign, OffsetBytes); + } + LhsSource = Builder.CreateBitCast(LhsSource, LoadSizeType->getPointerTo()); + RhsSource = Builder.CreateBitCast(RhsSource, LoadSizeType->getPointerTo()); + + // Create a constant or a load from the source. + Value *Lhs = nullptr; + if (auto *C = dyn_cast<Constant>(LhsSource)) + Lhs = ConstantFoldLoadFromConstPtr(C, LoadSizeType, DL); + if (!Lhs) + Lhs = Builder.CreateAlignedLoad(LoadSizeType, LhsSource, LhsAlign); + + Value *Rhs = nullptr; + if (auto *C = dyn_cast<Constant>(RhsSource)) + Rhs = ConstantFoldLoadFromConstPtr(C, LoadSizeType, DL); + if (!Rhs) + Rhs = Builder.CreateAlignedLoad(LoadSizeType, RhsSource, RhsAlign); + + // Swap bytes if required. + if (NeedsBSwap) { + Function *Bswap = Intrinsic::getDeclaration(CI->getModule(), + Intrinsic::bswap, LoadSizeType); + Lhs = Builder.CreateCall(Bswap, Lhs); + Rhs = Builder.CreateCall(Bswap, Rhs); + } + + // Zero extend if required. + if (CmpSizeType != nullptr && CmpSizeType != LoadSizeType) { + Lhs = Builder.CreateZExt(Lhs, CmpSizeType); + Rhs = Builder.CreateZExt(Rhs, CmpSizeType); } - return Builder.CreateBitCast(Source, LoadSizeType->getPointerTo()); + return {Lhs, Rhs}; } // This function creates the IR instructions for loading and comparing 1 byte. 
@@ -282,18 +326,10 @@ Value *MemCmpExpansion::getPtrToElementAtOffset(Value *Source, void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex, unsigned OffsetBytes) { Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]); - Type *LoadSizeType = Type::getInt8Ty(CI->getContext()); - Value *Source1 = - getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType, OffsetBytes); - Value *Source2 = - getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType, OffsetBytes); - - Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1); - Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2); - - LoadSrc1 = Builder.CreateZExt(LoadSrc1, Type::getInt32Ty(CI->getContext())); - LoadSrc2 = Builder.CreateZExt(LoadSrc2, Type::getInt32Ty(CI->getContext())); - Value *Diff = Builder.CreateSub(LoadSrc1, LoadSrc2); + const LoadPair Loads = + getLoadPair(Type::getInt8Ty(CI->getContext()), /*NeedsBSwap=*/false, + Type::getInt32Ty(CI->getContext()), OffsetBytes); + Value *Diff = Builder.CreateSub(Loads.Lhs, Loads.Rhs); PhiRes->addIncoming(Diff, LoadCmpBlocks[BlockIndex]); @@ -340,41 +376,19 @@ Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex, : IntegerType::get(CI->getContext(), MaxLoadSize * 8); for (unsigned i = 0; i < NumLoads; ++i, ++LoadIndex) { const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex]; - - IntegerType *LoadSizeType = - IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8); - - Value *Source1 = getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType, - CurLoadEntry.Offset); - Value *Source2 = getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType, - CurLoadEntry.Offset); - - // Get a constant or load a value for each source address. 
- Value *LoadSrc1 = nullptr; - if (auto *Source1C = dyn_cast<Constant>(Source1)) - LoadSrc1 = ConstantFoldLoadFromConstPtr(Source1C, LoadSizeType, DL); - if (!LoadSrc1) - LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1); - - Value *LoadSrc2 = nullptr; - if (auto *Source2C = dyn_cast<Constant>(Source2)) - LoadSrc2 = ConstantFoldLoadFromConstPtr(Source2C, LoadSizeType, DL); - if (!LoadSrc2) - LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2); + const LoadPair Loads = getLoadPair( + IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8), + /*NeedsBSwap=*/false, MaxLoadType, CurLoadEntry.Offset); if (NumLoads != 1) { - if (LoadSizeType != MaxLoadType) { - LoadSrc1 = Builder.CreateZExt(LoadSrc1, MaxLoadType); - LoadSrc2 = Builder.CreateZExt(LoadSrc2, MaxLoadType); - } // If we have multiple loads per block, we need to generate a composite // comparison using xor+or. - Diff = Builder.CreateXor(LoadSrc1, LoadSrc2); + Diff = Builder.CreateXor(Loads.Lhs, Loads.Rhs); Diff = Builder.CreateZExt(Diff, MaxLoadType); XorList.push_back(Diff); } else { // If there's only one load per block, we just compare the loaded values. - Cmp = Builder.CreateICmpNE(LoadSrc1, LoadSrc2); + Cmp = Builder.CreateICmpNE(Loads.Lhs, Loads.Rhs); } } @@ -451,35 +465,18 @@ void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) { Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]); - Value *Source1 = getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType, - CurLoadEntry.Offset); - Value *Source2 = getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType, - CurLoadEntry.Offset); - - // Load LoadSizeType from the base address. 
- Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1); - Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2); - - if (DL.isLittleEndian()) { - Function *Bswap = Intrinsic::getDeclaration(CI->getModule(), - Intrinsic::bswap, LoadSizeType); - LoadSrc1 = Builder.CreateCall(Bswap, LoadSrc1); - LoadSrc2 = Builder.CreateCall(Bswap, LoadSrc2); - } - - if (LoadSizeType != MaxLoadType) { - LoadSrc1 = Builder.CreateZExt(LoadSrc1, MaxLoadType); - LoadSrc2 = Builder.CreateZExt(LoadSrc2, MaxLoadType); - } + const LoadPair Loads = + getLoadPair(LoadSizeType, /*NeedsBSwap=*/DL.isLittleEndian(), MaxLoadType, + CurLoadEntry.Offset); // Add the loaded values to the phi nodes for calculating memcmp result only // if result is not used in a zero equality. if (!IsUsedForZeroCmp) { - ResBlock.PhiSrc1->addIncoming(LoadSrc1, LoadCmpBlocks[BlockIndex]); - ResBlock.PhiSrc2->addIncoming(LoadSrc2, LoadCmpBlocks[BlockIndex]); + ResBlock.PhiSrc1->addIncoming(Loads.Lhs, LoadCmpBlocks[BlockIndex]); + ResBlock.PhiSrc2->addIncoming(Loads.Rhs, LoadCmpBlocks[BlockIndex]); } - Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, LoadSrc1, LoadSrc2); + Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Loads.Lhs, Loads.Rhs); BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1)) ? EndBlock : LoadCmpBlocks[BlockIndex + 1]; @@ -568,42 +565,27 @@ Value *MemCmpExpansion::getMemCmpEqZeroOneBlock() { /// the compare, branch, and phi IR that is required in the general case. Value *MemCmpExpansion::getMemCmpOneBlock() { Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8); - Value *Source1 = CI->getArgOperand(0); - Value *Source2 = CI->getArgOperand(1); - - // Cast source to LoadSizeType*. - if (Source1->getType() != LoadSizeType) - Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo()); - if (Source2->getType() != LoadSizeType) - Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo()); - - // Load LoadSizeType from the base address. 
- Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1); - Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2); - - if (DL.isLittleEndian() && Size != 1) { - Function *Bswap = Intrinsic::getDeclaration(CI->getModule(), - Intrinsic::bswap, LoadSizeType); - LoadSrc1 = Builder.CreateCall(Bswap, LoadSrc1); - LoadSrc2 = Builder.CreateCall(Bswap, LoadSrc2); - } + bool NeedsBSwap = DL.isLittleEndian() && Size != 1; + // The i8 and i16 cases don't need compares. We zext the loaded values and + // subtract them to get the suitable negative, zero, or positive i32 result. if (Size < 4) { - // The i8 and i16 cases don't need compares. We zext the loaded values and - // subtract them to get the suitable negative, zero, or positive i32 result. - LoadSrc1 = Builder.CreateZExt(LoadSrc1, Builder.getInt32Ty()); - LoadSrc2 = Builder.CreateZExt(LoadSrc2, Builder.getInt32Ty()); - return Builder.CreateSub(LoadSrc1, LoadSrc2); + const LoadPair Loads = + getLoadPair(LoadSizeType, NeedsBSwap, Builder.getInt32Ty(), + /*Offset*/ 0); + return Builder.CreateSub(Loads.Lhs, Loads.Rhs); } + const LoadPair Loads = getLoadPair(LoadSizeType, NeedsBSwap, LoadSizeType, + /*Offset*/ 0); // The result of memcmp is negative, zero, or positive, so produce that by // subtracting 2 extended compare bits: sub (ugt, ult). // If a target prefers to use selects to get -1/0/1, they should be able // to transform this later. The inverse transform (going from selects to math) // may not be possible in the DAG because the selects got converted into // branches before we got there. 
- Value *CmpUGT = Builder.CreateICmpUGT(LoadSrc1, LoadSrc2); - Value *CmpULT = Builder.CreateICmpULT(LoadSrc1, LoadSrc2); + Value *CmpUGT = Builder.CreateICmpUGT(Loads.Lhs, Loads.Rhs); + Value *CmpULT = Builder.CreateICmpULT(Loads.Lhs, Loads.Rhs); Value *ZextUGT = Builder.CreateZExt(CmpUGT, Builder.getInt32Ty()); Value *ZextULT = Builder.CreateZExt(CmpULT, Builder.getInt32Ty()); return Builder.CreateSub(ZextUGT, ZextULT); @@ -843,7 +825,7 @@ bool ExpandMemCmpPass::runOnBlock( continue; } LibFunc Func; - if (TLI->getLibFunc(ImmutableCallSite(CI), Func) && + if (TLI->getLibFunc(*CI, Func) && (Func == LibFunc_memcmp || Func == LibFunc_bcmp) && expandMemCmp(CI, TTI, TL, &DL, PSI, BFI)) { return true; @@ -869,6 +851,9 @@ PreservedAnalyses ExpandMemCmpPass::runImpl( ++BBIt; } } + if (MadeChanges) + for (BasicBlock &BB : F) + SimplifyInstructionsInBlock(&BB); return MadeChanges ? PreservedAnalyses::none() : PreservedAnalyses::all(); } |