diff options
| author | Dimitry Andric <dim@FreeBSD.org> | 2020-02-27 18:58:42 +0000 |
|---|---|---|
| committer | Dimitry Andric <dim@FreeBSD.org> | 2020-02-27 18:58:42 +0000 |
| commit | 92d00d6a94bb341a1ed677031280e14863d4bb28 (patch) | |
| tree | ddcfe83581ea87d22873ea8523f444a56575f930 /llvm/lib/Transforms/InstCombine | |
| parent | d75c7debad4509ece98792074e64b8a650a27bdb (diff) | |
Notes
Diffstat (limited to 'llvm/lib/Transforms/InstCombine')
| -rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp | 20 | ||||
| -rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp | 24 |
2 files changed, 41 insertions, 3 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index f38dc436722d..e49e6cec65c0 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -3494,7 +3494,8 @@ foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ, Instruction *NarrowestShift = XShift; Type *WidestTy = WidestShift->getType(); - assert(NarrowestShift->getType() == I.getOperand(0)->getType() && + Type *NarrowestTy = NarrowestShift->getType(); + assert(NarrowestTy == I.getOperand(0)->getType() && "We did not look past any shifts while matching XShift though."); bool HadTrunc = WidestTy != I.getOperand(0)->getType(); @@ -3533,6 +3534,23 @@ foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ, if (XShAmt->getType() != YShAmt->getType()) return nullptr; + // As input, we have the following pattern: + // icmp eq/ne (and ((x shift Q), (y oppositeshift K))), 0 + // We want to rewrite that as: + // icmp eq/ne (and (x shift (Q+K)), y), 0 iff (Q+K) u< bitwidth(x) + // While we know that originally (Q+K) would not overflow + // (because 2 * (N-1) u<= iN -1), we have looked past extensions of + // shift amounts. so it may now overflow in smaller bitwidth. + // To ensure that does not happen, we need to ensure that the total maximal + // shift amount is still representable in that smaller bit width. + unsigned MaximalPossibleTotalShiftAmount = + (WidestTy->getScalarSizeInBits() - 1) + + (NarrowestTy->getScalarSizeInBits() - 1); + APInt MaximalRepresentableShiftAmount = + APInt::getAllOnesValue(XShAmt->getType()->getScalarSizeInBits()); + if (MaximalRepresentableShiftAmount.ult(MaximalPossibleTotalShiftAmount)) + return nullptr; + // Can we fold (XShAmt+YShAmt) ? auto *NewShAmt = dyn_cast_or_null<Constant>( SimplifyAddInst(XShAmt, YShAmt, /*isNSW=*/false, diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp index fbff5dd4a8cd..739579e2d38e 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -23,8 +23,11 @@ using namespace PatternMatch; // Given pattern: // (x shiftopcode Q) shiftopcode K // we should rewrite it as -// x shiftopcode (Q+K) iff (Q+K) u< bitwidth(x) -// This is valid for any shift, but they must be identical. +// x shiftopcode (Q+K) iff (Q+K) u< bitwidth(x) and +// +// This is valid for any shift, but they must be identical, and we must be +// careful in case we have (zext(Q)+zext(K)) and look past extensions, +// (Q+K) must not overflow or else (Q+K) u< bitwidth(x) is bogus. // // AnalyzeForSignBitExtraction indicates that we will only analyze whether this // pattern has any 2 right-shifts that sum to 1 less than original bit width. @@ -58,6 +61,23 @@ Value *InstCombiner::reassociateShiftAmtsOfTwoSameDirectionShifts( if (ShAmt0->getType() != ShAmt1->getType()) return nullptr; + // As input, we have the following pattern: + // Sh0 (Sh1 X, Q), K + // We want to rewrite that as: + // Sh x, (Q+K) iff (Q+K) u< bitwidth(x) + // While we know that originally (Q+K) would not overflow + // (because 2 * (N-1) u<= iN -1), we have looked past extensions of + // shift amounts. so it may now overflow in smaller bitwidth. + // To ensure that does not happen, we need to ensure that the total maximal + // shift amount is still representable in that smaller bit width. + unsigned MaximalPossibleTotalShiftAmount = + (Sh0->getType()->getScalarSizeInBits() - 1) + + (Sh1->getType()->getScalarSizeInBits() - 1); + APInt MaximalRepresentableShiftAmount = + APInt::getAllOnesValue(ShAmt0->getType()->getScalarSizeInBits()); + if (MaximalRepresentableShiftAmount.ult(MaximalPossibleTotalShiftAmount)) + return nullptr; + // We are only looking for signbit extraction if we have two right shifts. bool HadTwoRightShifts = match(Sh0, m_Shr(m_Value(), m_Value())) && match(Sh1, m_Shr(m_Value(), m_Value())); |
