summaryrefslogtreecommitdiff
path: root/llvm/lib/CodeGen/ExpandMemCmp.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/CodeGen/ExpandMemCmp.cpp')
-rw-r--r--llvm/lib/CodeGen/ExpandMemCmp.cpp148
1 files changed, 121 insertions, 27 deletions
diff --git a/llvm/lib/CodeGen/ExpandMemCmp.cpp b/llvm/lib/CodeGen/ExpandMemCmp.cpp
index 500f31bd8e89..e6ca14096249 100644
--- a/llvm/lib/CodeGen/ExpandMemCmp.cpp
+++ b/llvm/lib/CodeGen/ExpandMemCmp.cpp
@@ -23,6 +23,7 @@
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
@@ -31,6 +32,7 @@
#include <optional>
using namespace llvm;
+using namespace llvm::PatternMatch;
namespace llvm {
class TargetLowering;
@@ -117,8 +119,8 @@ class MemCmpExpansion {
Value *Lhs = nullptr;
Value *Rhs = nullptr;
};
- LoadPair getLoadPair(Type *LoadSizeType, bool NeedsBSwap, Type *CmpSizeType,
- unsigned OffsetBytes);
+ LoadPair getLoadPair(Type *LoadSizeType, Type *BSwapSizeType,
+ Type *CmpSizeType, unsigned OffsetBytes);
static LoadEntryVector
computeGreedyLoadSequence(uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes,
@@ -128,6 +130,11 @@ class MemCmpExpansion {
unsigned MaxNumLoads,
unsigned &NumLoadsNonOneByte);
+ static void optimiseLoadSequence(
+ LoadEntryVector &LoadSequence,
+ const TargetTransformInfo::MemCmpExpansionOptions &Options,
+ bool IsUsedForZeroCmp);
+
public:
MemCmpExpansion(CallInst *CI, uint64_t Size,
const TargetTransformInfo::MemCmpExpansionOptions &Options,
@@ -210,6 +217,37 @@ MemCmpExpansion::computeOverlappingLoadSequence(uint64_t Size,
return LoadSequence;
}
+void MemCmpExpansion::optimiseLoadSequence(
+ LoadEntryVector &LoadSequence,
+ const TargetTransformInfo::MemCmpExpansionOptions &Options,
+ bool IsUsedForZeroCmp) {
+ // Shrink LoadSequence in place by repeatedly fusing its last two entries
+ // into a single load, provided the combined size is listed in
+ // `MemCmpExpansionOptions::AllowedTailExpansions`. Nothing is done for a
+ // zero-equality comparison or when no tail expansion sizes are allowed.
+ if (IsUsedForZeroCmp || Options.AllowedTailExpansions.empty())
+ return;
+
+ while (LoadSequence.size() >= 2) {
+ auto Last = LoadSequence[LoadSequence.size() - 1];
+ auto PreLast = LoadSequence[LoadSequence.size() - 2];
+
+ // Stop merging if the last load does not start where the previous one ends
+ if (PreLast.Offset + PreLast.LoadSize != Last.Offset)
+ break;
+
+ auto LoadSize = Last.LoadSize + PreLast.LoadSize;
+ if (find(Options.AllowedTailExpansions, LoadSize) ==
+ Options.AllowedTailExpansions.end())
+ break;
+
+ // Replace the last two entries with one load at PreLast.Offset covering both
+ LoadSequence.pop_back();
+ LoadSequence.pop_back();
+ LoadSequence.emplace_back(PreLast.Offset, LoadSize);
+ }
+}
+
// Initialize the basic block structure required for expansion of memcmp call
// with given maximum load size and memcmp size parameter.
// This structure includes:
@@ -255,6 +293,7 @@ MemCmpExpansion::MemCmpExpansion(
}
}
assert(LoadSequence.size() <= Options.MaxNumLoads && "broken invariant");
+ optimiseLoadSequence(LoadSequence, Options, IsUsedForZeroCmp);
}
unsigned MemCmpExpansion::getNumBlocks() {
@@ -278,7 +317,7 @@ void MemCmpExpansion::createResultBlock() {
}
MemCmpExpansion::LoadPair MemCmpExpansion::getLoadPair(Type *LoadSizeType,
- bool NeedsBSwap,
+ Type *BSwapSizeType,
Type *CmpSizeType,
unsigned OffsetBytes) {
// Get the memory source at offset `OffsetBytes`.
@@ -307,16 +346,22 @@ MemCmpExpansion::LoadPair MemCmpExpansion::getLoadPair(Type *LoadSizeType,
if (!Rhs)
Rhs = Builder.CreateAlignedLoad(LoadSizeType, RhsSource, RhsAlign);
+ // Zero extend if Byte Swap intrinsic has different type
+ if (BSwapSizeType && LoadSizeType != BSwapSizeType) {
+ Lhs = Builder.CreateZExt(Lhs, BSwapSizeType);
+ Rhs = Builder.CreateZExt(Rhs, BSwapSizeType);
+ }
+
// Swap bytes if required.
- if (NeedsBSwap) {
- Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
- Intrinsic::bswap, LoadSizeType);
+ if (BSwapSizeType) {
+ Function *Bswap = Intrinsic::getDeclaration(
+ CI->getModule(), Intrinsic::bswap, BSwapSizeType);
Lhs = Builder.CreateCall(Bswap, Lhs);
Rhs = Builder.CreateCall(Bswap, Rhs);
}
// Zero extend if required.
- if (CmpSizeType != nullptr && CmpSizeType != LoadSizeType) {
+ if (CmpSizeType != nullptr && CmpSizeType != Lhs->getType()) {
Lhs = Builder.CreateZExt(Lhs, CmpSizeType);
Rhs = Builder.CreateZExt(Rhs, CmpSizeType);
}
@@ -332,7 +377,7 @@ void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex,
BasicBlock *BB = LoadCmpBlocks[BlockIndex];
Builder.SetInsertPoint(BB);
const LoadPair Loads =
- getLoadPair(Type::getInt8Ty(CI->getContext()), /*NeedsBSwap=*/false,
+ getLoadPair(Type::getInt8Ty(CI->getContext()), nullptr,
Type::getInt32Ty(CI->getContext()), OffsetBytes);
Value *Diff = Builder.CreateSub(Loads.Lhs, Loads.Rhs);
@@ -385,11 +430,12 @@ Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex,
IntegerType *const MaxLoadType =
NumLoads == 1 ? nullptr
: IntegerType::get(CI->getContext(), MaxLoadSize * 8);
+
for (unsigned i = 0; i < NumLoads; ++i, ++LoadIndex) {
const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex];
const LoadPair Loads = getLoadPair(
- IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8),
- /*NeedsBSwap=*/false, MaxLoadType, CurLoadEntry.Offset);
+ IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8), nullptr,
+ MaxLoadType, CurLoadEntry.Offset);
if (NumLoads != 1) {
// If we have multiple loads per block, we need to generate a composite
@@ -475,14 +521,20 @@ void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) {
Type *LoadSizeType =
IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);
- Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
+ Type *BSwapSizeType =
+ DL.isLittleEndian()
+ ? IntegerType::get(CI->getContext(),
+ PowerOf2Ceil(CurLoadEntry.LoadSize * 8))
+ : nullptr;
+ Type *MaxLoadType = IntegerType::get(
+ CI->getContext(),
+ std::max(MaxLoadSize, (unsigned)PowerOf2Ceil(CurLoadEntry.LoadSize)) * 8);
assert(CurLoadEntry.LoadSize <= MaxLoadSize && "Unexpected load type");
Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
- const LoadPair Loads =
- getLoadPair(LoadSizeType, /*NeedsBSwap=*/DL.isLittleEndian(), MaxLoadType,
- CurLoadEntry.Offset);
+ const LoadPair Loads = getLoadPair(LoadSizeType, BSwapSizeType, MaxLoadType,
+ CurLoadEntry.Offset);
// Add the loaded values to the phi nodes for calculating memcmp result only
// if result is not used in a zero equality.
@@ -558,7 +610,7 @@ void MemCmpExpansion::setupResultBlockPHINodes() {
}
void MemCmpExpansion::setupEndBlockPHINodes() {
- Builder.SetInsertPoint(&EndBlock->front());
+ Builder.SetInsertPoint(EndBlock, EndBlock->begin());
PhiRes = Builder.CreatePHI(Type::getInt32Ty(CI->getContext()), 2, "phi.res");
}
@@ -586,21 +638,63 @@ Value *MemCmpExpansion::getMemCmpEqZeroOneBlock() {
/// A memcmp expansion that only has one block of load and compare can bypass
/// the compare, branch, and phi IR that is required in the general case.
+/// This function also analyses users of memcmp, and if there is only one user
+/// from which we can conclude that only 2 out of 3 memcmp outcomes really
+/// matter, then it generates more efficient code with only one comparison.
Value *MemCmpExpansion::getMemCmpOneBlock() {
- Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8);
bool NeedsBSwap = DL.isLittleEndian() && Size != 1;
+ Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8);
+ Type *BSwapSizeType =
+ NeedsBSwap ? IntegerType::get(CI->getContext(), PowerOf2Ceil(Size * 8))
+ : nullptr;
+ Type *MaxLoadType =
+ IntegerType::get(CI->getContext(),
+ std::max(MaxLoadSize, (unsigned)PowerOf2Ceil(Size)) * 8);
// The i8 and i16 cases don't need compares. We zext the loaded values and
// subtract them to get the suitable negative, zero, or positive i32 result.
- if (Size < 4) {
- const LoadPair Loads =
- getLoadPair(LoadSizeType, NeedsBSwap, Builder.getInt32Ty(),
- /*Offset*/ 0);
+ if (Size == 1 || Size == 2) {
+ const LoadPair Loads = getLoadPair(LoadSizeType, BSwapSizeType,
+ Builder.getInt32Ty(), /*Offset*/ 0);
return Builder.CreateSub(Loads.Lhs, Loads.Rhs);
}
- const LoadPair Loads = getLoadPair(LoadSizeType, NeedsBSwap, LoadSizeType,
+ const LoadPair Loads = getLoadPair(LoadSizeType, BSwapSizeType, MaxLoadType,
/*Offset*/ 0);
+
+ // If a user of memcmp cares only about two outcomes, for example:
+ // bool result = memcmp(a, b, NBYTES) > 0;
+ // We can generate more optimal code with a smaller number of operations
+ if (CI->hasOneUser()) {
+ auto *UI = cast<Instruction>(*CI->user_begin());
+ ICmpInst::Predicate Pred = ICmpInst::Predicate::BAD_ICMP_PREDICATE;
+ uint64_t Shift;
+ bool NeedsZExt = false;
+ // This is a special case because instead of checking if the result is less
+ // than zero:
+ // bool result = memcmp(a, b, NBYTES) < 0;
+ // Compiler is clever enough to generate the following code:
+ // bool result = memcmp(a, b, NBYTES) >> 31;
+ if (match(UI, m_LShr(m_Value(), m_ConstantInt(Shift))) &&
+ Shift == (CI->getType()->getIntegerBitWidth() - 1)) {
+ Pred = ICmpInst::ICMP_SLT;
+ NeedsZExt = true;
+ } else {
+ // In case of a successful match this call will set `Pred` variable
+ match(UI, m_ICmp(Pred, m_Specific(CI), m_Zero()));
+ }
+ // Generate new code and remove the original memcmp call and the user
+ if (ICmpInst::isSigned(Pred)) {
+ Value *Cmp = Builder.CreateICmp(CmpInst::getUnsignedPredicate(Pred),
+ Loads.Lhs, Loads.Rhs);
+ auto *Result = NeedsZExt ? Builder.CreateZExt(Cmp, UI->getType()) : Cmp;
+ UI->replaceAllUsesWith(Result);
+ UI->eraseFromParent();
+ CI->eraseFromParent();
+ return nullptr;
+ }
+ }
+
// The result of memcmp is negative, zero, or positive, so produce that by
// subtracting 2 extended compare bits: sub (ugt, ult).
// If a target prefers to use selects to get -1/0/1, they should be able
@@ -615,7 +709,7 @@ Value *MemCmpExpansion::getMemCmpOneBlock() {
}
// This function expands the memcmp call into an inline expansion and returns
-// the memcmp result.
+// the memcmp result. Returns nullptr if the memcmp is already replaced.
Value *MemCmpExpansion::getMemCmpExpansion() {
// Create the basic block framework for a multi-block expansion.
if (getNumBlocks() != 1) {
@@ -783,11 +877,11 @@ static bool expandMemCmp(CallInst *CI, const TargetTransformInfo *TTI,
NumMemCmpInlined++;
- Value *Res = Expansion.getMemCmpExpansion();
-
- // Replace call with result of expansion and erase call.
- CI->replaceAllUsesWith(Res);
- CI->eraseFromParent();
+ if (Value *Res = Expansion.getMemCmpExpansion()) {
+ // Replace call with result of expansion and erase call.
+ CI->replaceAllUsesWith(Res);
+ CI->eraseFromParent();
+ }
return true;
}