summaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp')
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp43
1 files changed, 28 insertions, 15 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index fbff5dd4a8cd5..0a842b4e10475 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -23,8 +23,11 @@ using namespace PatternMatch;
// Given pattern:
// (x shiftopcode Q) shiftopcode K
// we should rewrite it as
-// x shiftopcode (Q+K) iff (Q+K) u< bitwidth(x)
-// This is valid for any shift, but they must be identical.
+// x shiftopcode (Q+K) iff (Q+K) u< bitwidth(x) and
+//
+// This is valid for any shift, but they must be identical, and we must be
+// careful in case we have (zext(Q)+zext(K)) and look past extensions,
+// (Q+K) must not overflow or else (Q+K) u< bitwidth(x) is bogus.
//
// AnalyzeForSignBitExtraction indicates that we will only analyze whether this
// pattern has any 2 right-shifts that sum to 1 less than original bit width.
@@ -58,6 +61,23 @@ Value *InstCombiner::reassociateShiftAmtsOfTwoSameDirectionShifts(
if (ShAmt0->getType() != ShAmt1->getType())
return nullptr;
+ // As input, we have the following pattern:
+ // Sh0 (Sh1 X, Q), K
+ // We want to rewrite that as:
+ // Sh x, (Q+K) iff (Q+K) u< bitwidth(x)
+ // While we know that originally (Q+K) would not overflow
+ // (because 2 * (N-1) u<= iN -1), we have looked past extensions of
+ // shift amounts. so it may now overflow in smaller bitwidth.
+ // To ensure that does not happen, we need to ensure that the total maximal
+ // shift amount is still representable in that smaller bit width.
+ unsigned MaximalPossibleTotalShiftAmount =
+ (Sh0->getType()->getScalarSizeInBits() - 1) +
+ (Sh1->getType()->getScalarSizeInBits() - 1);
+ APInt MaximalRepresentableShiftAmount =
+ APInt::getAllOnesValue(ShAmt0->getType()->getScalarSizeInBits());
+ if (MaximalRepresentableShiftAmount.ult(MaximalPossibleTotalShiftAmount))
+ return nullptr;
+
// We are only looking for signbit extraction if we have two right shifts.
bool HadTwoRightShifts = match(Sh0, m_Shr(m_Value(), m_Value())) &&
match(Sh1, m_Shr(m_Value(), m_Value()));
@@ -388,8 +408,7 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {
// demand the sign bit (and many others) here??
Value *Rem = Builder.CreateAnd(A, ConstantInt::get(I.getType(), *B - 1),
Op1->getName());
- I.setOperand(1, Rem);
- return &I;
+ return replaceOperand(I, 1, Rem);
}
if (Instruction *Logic = foldShiftOfShiftedLogic(I, Builder))
@@ -593,19 +612,13 @@ static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,
// We can always evaluate constants shifted.
if (Constant *C = dyn_cast<Constant>(V)) {
if (isLeftShift)
- V = IC.Builder.CreateShl(C, NumBits);
+ return IC.Builder.CreateShl(C, NumBits);
else
- V = IC.Builder.CreateLShr(C, NumBits);
- // If we got a constantexpr back, try to simplify it with TD info.
- if (auto *C = dyn_cast<Constant>(V))
- if (auto *FoldedC =
- ConstantFoldConstant(C, DL, &IC.getTargetLibraryInfo()))
- V = FoldedC;
- return V;
+ return IC.Builder.CreateLShr(C, NumBits);
}
Instruction *I = cast<Instruction>(V);
- IC.Worklist.Add(I);
+ IC.Worklist.push(I);
switch (I->getOpcode()) {
default: llvm_unreachable("Inconsistency with CanEvaluateShifted");
@@ -761,7 +774,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val);
Constant *Mask = ConstantInt::get(I.getContext(), Bits);
if (VectorType *VT = dyn_cast<VectorType>(X->getType()))
- Mask = ConstantVector::getSplat(VT->getNumElements(), Mask);
+ Mask = ConstantVector::getSplat(VT->getElementCount(), Mask);
return BinaryOperator::CreateAnd(X, Mask);
}
@@ -796,7 +809,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val);
Constant *Mask = ConstantInt::get(I.getContext(), Bits);
if (VectorType *VT = dyn_cast<VectorType>(X->getType()))
- Mask = ConstantVector::getSplat(VT->getNumElements(), Mask);
+ Mask = ConstantVector::getSplat(VT->getElementCount(), Mask);
return BinaryOperator::CreateAnd(X, Mask);
}