diff options
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp')
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp | 43 |
1 files changed, 28 insertions, 15 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp index fbff5dd4a8cd5..0a842b4e10475 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -23,8 +23,11 @@ using namespace PatternMatch; // Given pattern: // (x shiftopcode Q) shiftopcode K // we should rewrite it as -// x shiftopcode (Q+K) iff (Q+K) u< bitwidth(x) -// This is valid for any shift, but they must be identical. +// x shiftopcode (Q+K) iff (Q+K) u< bitwidth(x) and +// +// This is valid for any shift, but they must be identical, and we must be +// careful in case we have (zext(Q)+zext(K)) and look past extensions, +// (Q+K) must not overflow or else (Q+K) u< bitwidth(x) is bogus. // // AnalyzeForSignBitExtraction indicates that we will only analyze whether this // pattern has any 2 right-shifts that sum to 1 less than original bit width. @@ -58,6 +61,23 @@ Value *InstCombiner::reassociateShiftAmtsOfTwoSameDirectionShifts( if (ShAmt0->getType() != ShAmt1->getType()) return nullptr; + // As input, we have the following pattern: + // Sh0 (Sh1 X, Q), K + // We want to rewrite that as: + // Sh x, (Q+K) iff (Q+K) u< bitwidth(x) + // While we know that originally (Q+K) would not overflow + // (because 2 * (N-1) u<= iN -1), we have looked past extensions of + // shift amounts. so it may now overflow in smaller bitwidth. + // To ensure that does not happen, we need to ensure that the total maximal + // shift amount is still representable in that smaller bit width. + unsigned MaximalPossibleTotalShiftAmount = + (Sh0->getType()->getScalarSizeInBits() - 1) + + (Sh1->getType()->getScalarSizeInBits() - 1); + APInt MaximalRepresentableShiftAmount = + APInt::getAllOnesValue(ShAmt0->getType()->getScalarSizeInBits()); + if (MaximalRepresentableShiftAmount.ult(MaximalPossibleTotalShiftAmount)) + return nullptr; + // We are only looking for signbit extraction if we have two right shifts. bool HadTwoRightShifts = match(Sh0, m_Shr(m_Value(), m_Value())) && match(Sh1, m_Shr(m_Value(), m_Value())); @@ -388,8 +408,7 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { // demand the sign bit (and many others) here?? Value *Rem = Builder.CreateAnd(A, ConstantInt::get(I.getType(), *B - 1), Op1->getName()); - I.setOperand(1, Rem); - return &I; + return replaceOperand(I, 1, Rem); } if (Instruction *Logic = foldShiftOfShiftedLogic(I, Builder)) @@ -593,19 +612,13 @@ static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift, // We can always evaluate constants shifted. if (Constant *C = dyn_cast<Constant>(V)) { if (isLeftShift) - V = IC.Builder.CreateShl(C, NumBits); + return IC.Builder.CreateShl(C, NumBits); else - V = IC.Builder.CreateLShr(C, NumBits); - // If we got a constantexpr back, try to simplify it with TD info. - if (auto *C = dyn_cast<Constant>(V)) - if (auto *FoldedC = - ConstantFoldConstant(C, DL, &IC.getTargetLibraryInfo())) - V = FoldedC; - return V; + return IC.Builder.CreateLShr(C, NumBits); } Instruction *I = cast<Instruction>(V); - IC.Worklist.Add(I); + IC.Worklist.push(I); switch (I->getOpcode()) { default: llvm_unreachable("Inconsistency with CanEvaluateShifted"); @@ -761,7 +774,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val); Constant *Mask = ConstantInt::get(I.getContext(), Bits); if (VectorType *VT = dyn_cast<VectorType>(X->getType())) - Mask = ConstantVector::getSplat(VT->getNumElements(), Mask); + Mask = ConstantVector::getSplat(VT->getElementCount(), Mask); return BinaryOperator::CreateAnd(X, Mask); } @@ -796,7 +809,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val); Constant *Mask = ConstantInt::get(I.getContext(), Bits); if (VectorType *VT = dyn_cast<VectorType>(X->getType())) - Mask = ConstantVector::getSplat(VT->getNumElements(), Mask); + Mask = ConstantVector::getSplat(VT->getElementCount(), Mask); return BinaryOperator::CreateAnd(X, Mask); } |