path: root/lib/Transforms/InstCombine
author    Dimitry Andric <dim@FreeBSD.org>  2017-06-16 21:03:24 +0000
committer Dimitry Andric <dim@FreeBSD.org>  2017-06-16 21:03:24 +0000
commit    7c7aba6e5fef47a01a136be655b0a92cfd7090f6 (patch)
tree      99ec531924f6078534b100ab9d7696abce848099 /lib/Transforms/InstCombine
parent    7ab83427af0f77b59941ceba41d509d7d097b065 (diff)
Diffstat (limited to 'lib/Transforms/InstCombine')
-rw-r--r--  lib/Transforms/InstCombine/InstCombineAndOrXor.cpp | 113
-rw-r--r--  lib/Transforms/InstCombine/InstCombineCalls.cpp    | 113
-rw-r--r--  lib/Transforms/InstCombine/InstCombineInternal.h   |   9
-rw-r--r--  lib/Transforms/InstCombine/InstCombineShifts.cpp   |  14
4 files changed, 144 insertions, 105 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 4fe3225a2172..a881bda5ba98 100644
--- a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -763,8 +763,54 @@ foldAndOrOfEqualityCmpsWithConstants(ICmpInst *LHS, ICmpInst *RHS,
return nullptr;
}
+// Fold (iszero(A & K1) | iszero(A & K2)) -> (A & (K1 | K2)) != (K1 | K2)
+// Fold (!iszero(A & K1) & !iszero(A & K2)) -> (A & (K1 | K2)) == (K1 | K2)
+Value *InstCombiner::foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS,
+ bool JoinedByAnd,
+ Instruction &CxtI) {
+ ICmpInst::Predicate Pred = LHS->getPredicate();
+ if (Pred != RHS->getPredicate())
+ return nullptr;
+ if (JoinedByAnd && Pred != ICmpInst::ICMP_NE)
+ return nullptr;
+ if (!JoinedByAnd && Pred != ICmpInst::ICMP_EQ)
+ return nullptr;
+
+ // TODO: Support vector splats.
+ ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1));
+ ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS->getOperand(1));
+ if (!LHSC || !RHSC || !LHSC->isZero() || !RHSC->isZero())
+ return nullptr;
+
+ Value *A, *B, *C, *D;
+ if (match(LHS->getOperand(0), m_And(m_Value(A), m_Value(B))) &&
+ match(RHS->getOperand(0), m_And(m_Value(C), m_Value(D)))) {
+ if (A == D || B == D)
+ std::swap(C, D);
+ if (B == C)
+ std::swap(A, B);
+
+ if (A == C &&
+ isKnownToBeAPowerOfTwo(B, false, 0, &CxtI) &&
+ isKnownToBeAPowerOfTwo(D, false, 0, &CxtI)) {
+ Value *Mask = Builder->CreateOr(B, D);
+ Value *Masked = Builder->CreateAnd(A, Mask);
+ auto NewPred = JoinedByAnd ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
+ return Builder->CreateICmp(NewPred, Masked, Mask);
+ }
+ }
+
+ return nullptr;
+}
+
/// Fold (icmp)&(icmp) if possible.
-Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) {
+Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS,
+ Instruction &CxtI) {
+ // Fold (!iszero(A & K1) & !iszero(A & K2)) -> (A & (K1 | K2)) == (K1 | K2)
+ // if K1 and K2 are a one-bit mask.
+ if (Value *V = foldAndOrOfICmpsOfAndWithPow2(LHS, RHS, true, CxtI))
+ return V;
+
ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
// (icmp1 A, B) & (icmp2 A, B) --> (icmp3 A, B)
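
The foldAndOrOfICmpsOfAndWithPow2 fold added above is easy to sanity-check exhaustively at a small bit width. The following standalone C++ sketch (not part of the patch; K1 and K2 are arbitrary one-bit masks chosen for illustration) asserts both rewrites over all 8-bit values of A:

#include <cassert>
#include <cstdint>

int main() {
  const uint8_t K1 = 0x04, K2 = 0x20; // hypothetical power-of-two masks
  for (unsigned A = 0; A < 256; ++A) {
    const uint8_t a = static_cast<uint8_t>(A);
    // (iszero(A & K1) | iszero(A & K2))  ==  ((A & (K1 | K2)) != (K1 | K2))
    assert((((a & K1) == 0) || ((a & K2) == 0)) ==
           ((a & (K1 | K2)) != (K1 | K2)));
    // (!iszero(A & K1) & !iszero(A & K2))  ==  ((A & (K1 | K2)) == (K1 | K2))
    assert((((a & K1) != 0) && ((a & K2) != 0)) ==
           ((a & (K1 | K2)) == (K1 | K2)));
  }
}
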
@@ -1127,8 +1173,8 @@ Instruction *InstCombiner::foldCastedBitwiseLogic(BinaryOperator &I) {
ICmpInst *ICmp0 = dyn_cast<ICmpInst>(Cast0Src);
ICmpInst *ICmp1 = dyn_cast<ICmpInst>(Cast1Src);
if (ICmp0 && ICmp1) {
- Value *Res = LogicOpc == Instruction::And ? foldAndOfICmps(ICmp0, ICmp1)
- : foldOrOfICmps(ICmp0, ICmp1, &I);
+ Value *Res = LogicOpc == Instruction::And ? foldAndOfICmps(ICmp0, ICmp1, I)
+ : foldOrOfICmps(ICmp0, ICmp1, I);
if (Res)
return CastInst::Create(CastOpcode, Res, DestTy);
return nullptr;
@@ -1426,7 +1472,7 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) {
ICmpInst *LHS = dyn_cast<ICmpInst>(Op0);
ICmpInst *RHS = dyn_cast<ICmpInst>(Op1);
if (LHS && RHS)
- if (Value *Res = foldAndOfICmps(LHS, RHS))
+ if (Value *Res = foldAndOfICmps(LHS, RHS, I))
return replaceInstUsesWith(I, Res);
// TODO: Make this recursive; it's a little tricky because an arbitrary
@@ -1434,18 +1480,18 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) {
Value *X, *Y;
if (LHS && match(Op1, m_OneUse(m_And(m_Value(X), m_Value(Y))))) {
if (auto *Cmp = dyn_cast<ICmpInst>(X))
- if (Value *Res = foldAndOfICmps(LHS, Cmp))
+ if (Value *Res = foldAndOfICmps(LHS, Cmp, I))
return replaceInstUsesWith(I, Builder->CreateAnd(Res, Y));
if (auto *Cmp = dyn_cast<ICmpInst>(Y))
- if (Value *Res = foldAndOfICmps(LHS, Cmp))
+ if (Value *Res = foldAndOfICmps(LHS, Cmp, I))
return replaceInstUsesWith(I, Builder->CreateAnd(Res, X));
}
if (RHS && match(Op0, m_OneUse(m_And(m_Value(X), m_Value(Y))))) {
if (auto *Cmp = dyn_cast<ICmpInst>(X))
- if (Value *Res = foldAndOfICmps(Cmp, RHS))
+ if (Value *Res = foldAndOfICmps(Cmp, RHS, I))
return replaceInstUsesWith(I, Builder->CreateAnd(Res, Y));
if (auto *Cmp = dyn_cast<ICmpInst>(Y))
- if (Value *Res = foldAndOfICmps(Cmp, RHS))
+ if (Value *Res = foldAndOfICmps(Cmp, RHS, I))
return replaceInstUsesWith(I, Builder->CreateAnd(Res, X));
}
}
@@ -1591,41 +1637,16 @@ static Value *matchSelectFromAndOr(Value *A, Value *C, Value *B, Value *D,
/// Fold (icmp)|(icmp) if possible.
Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
- Instruction *CxtI) {
- ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
-
+ Instruction &CxtI) {
// Fold (iszero(A & K1) | iszero(A & K2)) -> (A & (K1 | K2)) != (K1 | K2)
// if K1 and K2 are a one-bit mask.
- ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1));
- ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS->getOperand(1));
-
- if (LHS->getPredicate() == ICmpInst::ICMP_EQ && LHSC && LHSC->isZero() &&
- RHS->getPredicate() == ICmpInst::ICMP_EQ && RHSC && RHSC->isZero()) {
-
- BinaryOperator *LAnd = dyn_cast<BinaryOperator>(LHS->getOperand(0));
- BinaryOperator *RAnd = dyn_cast<BinaryOperator>(RHS->getOperand(0));
- if (LAnd && RAnd && LAnd->hasOneUse() && RHS->hasOneUse() &&
- LAnd->getOpcode() == Instruction::And &&
- RAnd->getOpcode() == Instruction::And) {
+ if (Value *V = foldAndOrOfICmpsOfAndWithPow2(LHS, RHS, false, CxtI))
+ return V;
- Value *Mask = nullptr;
- Value *Masked = nullptr;
- if (LAnd->getOperand(0) == RAnd->getOperand(0) &&
- isKnownToBeAPowerOfTwo(LAnd->getOperand(1), false, 0, CxtI) &&
- isKnownToBeAPowerOfTwo(RAnd->getOperand(1), false, 0, CxtI)) {
- Mask = Builder->CreateOr(LAnd->getOperand(1), RAnd->getOperand(1));
- Masked = Builder->CreateAnd(LAnd->getOperand(0), Mask);
- } else if (LAnd->getOperand(1) == RAnd->getOperand(1) &&
- isKnownToBeAPowerOfTwo(LAnd->getOperand(0), false, 0, CxtI) &&
- isKnownToBeAPowerOfTwo(RAnd->getOperand(0), false, 0, CxtI)) {
- Mask = Builder->CreateOr(LAnd->getOperand(0), RAnd->getOperand(0));
- Masked = Builder->CreateAnd(LAnd->getOperand(1), Mask);
- }
+ ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
- if (Masked)
- return Builder->CreateICmp(ICmpInst::ICMP_NE, Masked, Mask);
- }
- }
+ ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1));
+ ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS->getOperand(1));
// Fold (icmp ult/ule (A + C1), C3) | (icmp ult/ule (A + C2), C3)
// --> (icmp ult/ule ((A & ~(C1 ^ C2)) + max(C1, C2)), C3)
@@ -2117,12 +2138,16 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) {
}
// (A ^ B) | ((B ^ C) ^ A) -> (A ^ B) | C
+ // FIXME: The two hasOneUse calls here are the same call; maybe we were
+ // supposed to check Op1->operand(0)?
if (match(Op0, m_Xor(m_Value(A), m_Value(B))))
if (match(Op1, m_Xor(m_Xor(m_Specific(B), m_Value(C)), m_Specific(A))))
if (Op1->hasOneUse() || cast<BinaryOperator>(Op1)->hasOneUse())
return BinaryOperator::CreateOr(Op0, C);
// ((A ^ C) ^ B) | (B ^ A) -> (B ^ A) | C
+ // FIXME: The two hasOneUse calls here are the same call; maybe we were
+ // supposed to check Op0->operand(0)?
if (match(Op0, m_Xor(m_Xor(m_Value(A), m_Value(C)), m_Value(B))))
if (match(Op1, m_Xor(m_Specific(B), m_Specific(A))))
if (Op0->hasOneUse() || cast<BinaryOperator>(Op0)->hasOneUse())
return BinaryOperator::CreateOr(Op1, C);
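
Both xor/or rewrites above follow from the identity X | (X ^ C) == X | C with X = A ^ B. A standalone brute-force check over 8-bit values (again, not part of the patch):

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned A = 0; A < 256; ++A)
    for (unsigned B = 0; B < 256; ++B)
      for (unsigned C = 0; C < 256; ++C) {
        const uint8_t a = A, b = B, c = C;
        // (A ^ B) | ((B ^ C) ^ A) -> (A ^ B) | C
        assert(((a ^ b) | ((b ^ c) ^ a)) == ((a ^ b) | c));
        // ((A ^ C) ^ B) | (B ^ A) -> (B ^ A) | C
        assert((((a ^ c) ^ b) | (b ^ a)) == ((b ^ a) | c));
      }
}
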
@@ -2194,7 +2219,7 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) {
ICmpInst *LHS = dyn_cast<ICmpInst>(Op0);
ICmpInst *RHS = dyn_cast<ICmpInst>(Op1);
if (LHS && RHS)
- if (Value *Res = foldOrOfICmps(LHS, RHS, &I))
+ if (Value *Res = foldOrOfICmps(LHS, RHS, I))
return replaceInstUsesWith(I, Res);
// TODO: Make this recursive; it's a little tricky because an arbitrary
@@ -2202,18 +2227,18 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) {
Value *X, *Y;
if (LHS && match(Op1, m_OneUse(m_Or(m_Value(X), m_Value(Y))))) {
if (auto *Cmp = dyn_cast<ICmpInst>(X))
- if (Value *Res = foldOrOfICmps(LHS, Cmp, &I))
+ if (Value *Res = foldOrOfICmps(LHS, Cmp, I))
return replaceInstUsesWith(I, Builder->CreateOr(Res, Y));
if (auto *Cmp = dyn_cast<ICmpInst>(Y))
- if (Value *Res = foldOrOfICmps(LHS, Cmp, &I))
+ if (Value *Res = foldOrOfICmps(LHS, Cmp, I))
return replaceInstUsesWith(I, Builder->CreateOr(Res, X));
}
if (RHS && match(Op0, m_OneUse(m_Or(m_Value(X), m_Value(Y))))) {
if (auto *Cmp = dyn_cast<ICmpInst>(X))
- if (Value *Res = foldOrOfICmps(Cmp, RHS, &I))
+ if (Value *Res = foldOrOfICmps(Cmp, RHS, I))
return replaceInstUsesWith(I, Builder->CreateOr(Res, Y));
if (auto *Cmp = dyn_cast<ICmpInst>(Y))
- if (Value *Res = foldOrOfICmps(Cmp, RHS, &I))
+ if (Value *Res = foldOrOfICmps(Cmp, RHS, I))
return replaceInstUsesWith(I, Builder->CreateOr(Res, X));
}
}
diff --git a/lib/Transforms/InstCombine/InstCombineCalls.cpp b/lib/Transforms/InstCombine/InstCombineCalls.cpp
index d29ed49eca0b..c0830a5d2112 100644
--- a/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -94,75 +94,80 @@ static Constant *getNegativeIsTrueBoolVec(ConstantDataVector *V) {
return ConstantVector::get(BoolVec);
}
-Instruction *
-InstCombiner::SimplifyElementAtomicMemCpy(ElementAtomicMemCpyInst *AMI) {
+Instruction *InstCombiner::SimplifyElementUnorderedAtomicMemCpy(
+ ElementUnorderedAtomicMemCpyInst *AMI) {
// Try to unfold this intrinsic into sequence of explicit atomic loads and
// stores.
// First check that number of elements is compile time constant.
- auto *NumElementsCI = dyn_cast<ConstantInt>(AMI->getNumElements());
- if (!NumElementsCI)
+ auto *LengthCI = dyn_cast<ConstantInt>(AMI->getLength());
+ if (!LengthCI)
return nullptr;
// Check that there are not too many elements.
- uint64_t NumElements = NumElementsCI->getZExtValue();
+ uint64_t LengthInBytes = LengthCI->getZExtValue();
+ uint32_t ElementSizeInBytes = AMI->getElementSizeInBytes();
+ uint64_t NumElements = LengthInBytes / ElementSizeInBytes;
if (NumElements >= UnfoldElementAtomicMemcpyMaxElements)
return nullptr;
- // Don't unfold into illegal integers
- uint64_t ElementSizeInBytes = AMI->getElementSizeInBytes() * 8;
- if (!getDataLayout().isLegalInteger(ElementSizeInBytes))
- return nullptr;
+ // Only expand if there are elements to copy.
+ if (NumElements > 0) {
+ // Don't unfold into illegal integers
+ uint64_t ElementSizeInBits = ElementSizeInBytes * 8;
+ if (!getDataLayout().isLegalInteger(ElementSizeInBits))
+ return nullptr;
- // Cast source and destination to the correct type. Intrinsic input arguments
- // are usually represented as i8*.
- // Often operands will be explicitly casted to i8* and we can just strip
- // those casts instead of inserting new ones. However it's easier to rely on
- // other InstCombine rules which will cover trivial cases anyway.
- Value *Src = AMI->getRawSource();
- Value *Dst = AMI->getRawDest();
- Type *ElementPointerType = Type::getIntNPtrTy(
- AMI->getContext(), ElementSizeInBytes, Src->getType()->getPointerAddressSpace());
+ // Cast source and destination to the correct type. Intrinsic input
+ // arguments are usually represented as i8*. Often operands will be
+ // explicitly cast to i8*, and we can just strip those casts instead of
+ // inserting new ones. However, it's easier to rely on other InstCombine
+ // rules, which will cover the trivial cases anyway.
+ Value *Src = AMI->getRawSource();
+ Value *Dst = AMI->getRawDest();
+ Type *ElementPointerType =
+ Type::getIntNPtrTy(AMI->getContext(), ElementSizeInBits,
+ Src->getType()->getPointerAddressSpace());
- Value *SrcCasted = Builder->CreatePointerCast(Src, ElementPointerType,
- "memcpy_unfold.src_casted");
- Value *DstCasted = Builder->CreatePointerCast(Dst, ElementPointerType,
- "memcpy_unfold.dst_casted");
+ Value *SrcCasted = Builder->CreatePointerCast(Src, ElementPointerType,
+ "memcpy_unfold.src_casted");
+ Value *DstCasted = Builder->CreatePointerCast(Dst, ElementPointerType,
+ "memcpy_unfold.dst_casted");
- for (uint64_t i = 0; i < NumElements; ++i) {
- // Get current element addresses
- ConstantInt *ElementIdxCI =
- ConstantInt::get(AMI->getContext(), APInt(64, i));
- Value *SrcElementAddr =
- Builder->CreateGEP(SrcCasted, ElementIdxCI, "memcpy_unfold.src_addr");
- Value *DstElementAddr =
- Builder->CreateGEP(DstCasted, ElementIdxCI, "memcpy_unfold.dst_addr");
+ for (uint64_t i = 0; i < NumElements; ++i) {
+ // Get current element addresses
+ ConstantInt *ElementIdxCI =
+ ConstantInt::get(AMI->getContext(), APInt(64, i));
+ Value *SrcElementAddr =
+ Builder->CreateGEP(SrcCasted, ElementIdxCI, "memcpy_unfold.src_addr");
+ Value *DstElementAddr =
+ Builder->CreateGEP(DstCasted, ElementIdxCI, "memcpy_unfold.dst_addr");
- // Load from the source. Transfer alignment information and mark load as
- // unordered atomic.
- LoadInst *Load = Builder->CreateLoad(SrcElementAddr, "memcpy_unfold.val");
- Load->setOrdering(AtomicOrdering::Unordered);
- // We know alignment of the first element. It is also guaranteed by the
- // verifier that element size is less or equal than first element alignment
- // and both of this values are powers of two.
- // This means that all subsequent accesses are at least element size
- // aligned.
- // TODO: We can infer better alignment but there is no evidence that this
- // will matter.
- Load->setAlignment(i == 0 ? AMI->getSrcAlignment()
- : AMI->getElementSizeInBytes());
- Load->setDebugLoc(AMI->getDebugLoc());
+ // Load from the source. Transfer alignment information and mark load as
+ // unordered atomic.
+ LoadInst *Load = Builder->CreateLoad(SrcElementAddr, "memcpy_unfold.val");
+ Load->setOrdering(AtomicOrdering::Unordered);
+ // We know the alignment of the first element. The verifier also
+ // guarantees that the element size is less than or equal to the first
+ // element's alignment and that both of these values are powers of two.
+ // This means that all subsequent accesses are at least element-size
+ // aligned.
+ // TODO: We can infer better alignment but there is no evidence that this
+ // will matter.
+ Load->setAlignment(i == 0 ? AMI->getParamAlignment(1)
+ : ElementSizeInBytes);
+ Load->setDebugLoc(AMI->getDebugLoc());
- // Store loaded value via unordered atomic store.
- StoreInst *Store = Builder->CreateStore(Load, DstElementAddr);
- Store->setOrdering(AtomicOrdering::Unordered);
- Store->setAlignment(i == 0 ? AMI->getDstAlignment()
- : AMI->getElementSizeInBytes());
- Store->setDebugLoc(AMI->getDebugLoc());
+ // Store loaded value via unordered atomic store.
+ StoreInst *Store = Builder->CreateStore(Load, DstElementAddr);
+ Store->setOrdering(AtomicOrdering::Unordered);
+ Store->setAlignment(i == 0 ? AMI->getParamAlignment(0)
+ : ElementSizeInBytes);
+ Store->setDebugLoc(AMI->getDebugLoc());
+ }
}
// Set the number of elements of the copy to 0, it will be deleted on the
// next iteration.
- AMI->setNumElements(Constant::getNullValue(NumElementsCI->getType()));
+ AMI->setLength(Constant::getNullValue(LengthCI->getType()));
return AMI;
}
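
For reference, the unfolding now emits one unordered atomic load/store pair per element. A loose C++ analogue of the expansion (not part of the patch; LLVM's "unordered" ordering has no exact C++ equivalent, so memory_order_relaxed stands in for it here):

#include <atomic>
#include <cstdint>

// Rough shape of the expansion for element size 4 and length 16 bytes
// (NumElements = 4). In the real transform the first access keeps the
// intrinsic's parameter alignment; later accesses are known to be at
// least element-size aligned, per the verifier guarantee noted above.
void unfolded_copy(std::atomic<uint32_t> *dst,
                   const std::atomic<uint32_t> *src) {
  for (uint64_t i = 0; i < 4; ++i) {
    uint32_t v = src[i].load(std::memory_order_relaxed); // memcpy_unfold.val
    dst[i].store(v, std::memory_order_relaxed);
  }
}
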
@@ -1888,12 +1893,12 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
if (Changed) return II;
}
- if (auto *AMI = dyn_cast<ElementAtomicMemCpyInst>(II)) {
- if (Constant *C = dyn_cast<Constant>(AMI->getNumElements()))
+ if (auto *AMI = dyn_cast<ElementUnorderedAtomicMemCpyInst>(II)) {
+ if (Constant *C = dyn_cast<Constant>(AMI->getLength()))
if (C->isNullValue())
return eraseInstFromFunction(*AMI);
- if (Instruction *I = SimplifyElementAtomicMemCpy(AMI))
+ if (Instruction *I = SimplifyElementUnorderedAtomicMemCpy(AMI))
return I;
}
diff --git a/lib/Transforms/InstCombine/InstCombineInternal.h b/lib/Transforms/InstCombine/InstCombineInternal.h
index fd0a64a5bbb5..1a7db146df42 100644
--- a/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -447,12 +447,14 @@ private:
Instruction::CastOps isEliminableCastPair(const CastInst *CI1,
const CastInst *CI2);
- Value *foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS);
+ Value *foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction &CxtI);
Value *foldAndOfFCmps(FCmpInst *LHS, FCmpInst *RHS);
- Value *foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction *CxtI);
+ Value *foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction &CxtI);
Value *foldOrOfFCmps(FCmpInst *LHS, FCmpInst *RHS);
Value *foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS);
+ Value *foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS,
+ bool JoinedByAnd, Instruction &CxtI);
public:
/// \brief Inserts an instruction \p New before instruction \p Old
///
@@ -724,7 +726,8 @@ private:
Instruction *MatchBSwap(BinaryOperator &I);
bool SimplifyStoreAtEndOfBlock(StoreInst &SI);
- Instruction *SimplifyElementAtomicMemCpy(ElementAtomicMemCpyInst *AMI);
+ Instruction *
+ SimplifyElementUnorderedAtomicMemCpy(ElementUnorderedAtomicMemCpyInst *AMI);
Instruction *SimplifyMemTransfer(MemIntrinsic *MI);
Instruction *SimplifyMemSet(MemSetInst *MI);
diff --git a/lib/Transforms/InstCombine/InstCombineShifts.cpp b/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 3f2ddcacce2b..8cec865c6422 100644
--- a/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -682,11 +682,11 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) {
return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask));
}
- if (match(Op0, m_SExt(m_Value(X)))) {
+ if (match(Op0, m_SExt(m_Value(X))) &&
+ (!Ty->isIntegerTy() || shouldChangeType(Ty, X->getType()))) {
// Are we moving the sign bit to the low bit and widening with high zeros?
unsigned SrcTyBitWidth = X->getType()->getScalarSizeInBits();
- if (ShAmt == BitWidth - 1 &&
- (!Ty->isIntegerTy() || shouldChangeType(Ty, X->getType()))) {
+ if (ShAmt == BitWidth - 1) {
// lshr (sext i1 X to iN), N-1 --> zext X to iN
if (SrcTyBitWidth == 1)
return new ZExtInst(X, Ty);
@@ -698,7 +698,13 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) {
}
}
- // TODO: Convert to ashr+zext if the shift equals the extension amount.
+ // lshr (sext iM X to iN), N-M --> zext (ashr X, min(N-M, M-1)) to iN
+ if (ShAmt == BitWidth - SrcTyBitWidth && Op0->hasOneUse()) {
+ // The new shift amount can't be more than the narrow source type's width.
+ unsigned NewShAmt = std::min(ShAmt, SrcTyBitWidth - 1);
+ Value *AShr = Builder->CreateAShr(X, NewShAmt);
+ return new ZExtInst(AShr, Ty);
+ }
}
if (match(Op0, m_LShr(m_Value(X), m_APInt(ShOp1)))) {
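
The replaced TODO above is the new ashr+zext fold. For the i8 -> i32 case (N = 32, M = 8, so ShAmt = 24 and the clamped shift amount is min(24, 7) = 7), it can be checked directly in standalone C++ (not part of the patch; assumes arithmetic right shift of signed values, which mainstream compilers provide and C++20 guarantees):

#include <cassert>
#include <cstdint>

int main() {
  for (int v = -128; v < 128; ++v) {
    const int8_t x = static_cast<int8_t>(v);
    // lshr (sext i8 X to i32), 24: sign-extend, then logical shift right.
    const uint32_t lhs = static_cast<uint32_t>(static_cast<int32_t>(x)) >> 24;
    // zext (ashr i8 X, 7) to i32: arithmetic shift in i8, then zero-extend.
    const uint32_t rhs = static_cast<uint8_t>(static_cast<int8_t>(x >> 7));
    assert(lhs == rhs);
  }
}
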