summaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/InstCombine
diff options
context:
space:
mode:
authorDimitry Andric <dim@FreeBSD.org>2020-02-27 18:58:42 +0000
committerDimitry Andric <dim@FreeBSD.org>2020-02-27 18:58:42 +0000
commit92d00d6a94bb341a1ed677031280e14863d4bb28 (patch)
treeddcfe83581ea87d22873ea8523f444a56575f930 /llvm/lib/Transforms/InstCombine
parentd75c7debad4509ece98792074e64b8a650a27bdb (diff)
Notes
Diffstat (limited to 'llvm/lib/Transforms/InstCombine')
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp20
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp24
2 files changed, 41 insertions, 3 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index f38dc436722d..e49e6cec65c0 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -3494,7 +3494,8 @@ foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ,
Instruction *NarrowestShift = XShift;
Type *WidestTy = WidestShift->getType();
- assert(NarrowestShift->getType() == I.getOperand(0)->getType() &&
+ Type *NarrowestTy = NarrowestShift->getType();
+ assert(NarrowestTy == I.getOperand(0)->getType() &&
"We did not look past any shifts while matching XShift though.");
bool HadTrunc = WidestTy != I.getOperand(0)->getType();
@@ -3533,6 +3534,23 @@ foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ,
if (XShAmt->getType() != YShAmt->getType())
return nullptr;
+ // As input, we have the following pattern:
+ // icmp eq/ne (and ((x shift Q), (y oppositeshift K))), 0
+ // We want to rewrite that as:
+ // icmp eq/ne (and (x shift (Q+K)), y), 0 iff (Q+K) u< bitwidth(x)
+ // While we know that originally (Q+K) would not overflow
+ // (because 2 * (N-1) u<= iN -1), we have looked past extensions of
+ // shift amounts. so it may now overflow in smaller bitwidth.
+ // To ensure that does not happen, we need to ensure that the total maximal
+ // shift amount is still representable in that smaller bit width.
+ unsigned MaximalPossibleTotalShiftAmount =
+ (WidestTy->getScalarSizeInBits() - 1) +
+ (NarrowestTy->getScalarSizeInBits() - 1);
+ APInt MaximalRepresentableShiftAmount =
+ APInt::getAllOnesValue(XShAmt->getType()->getScalarSizeInBits());
+ if (MaximalRepresentableShiftAmount.ult(MaximalPossibleTotalShiftAmount))
+ return nullptr;
+
// Can we fold (XShAmt+YShAmt) ?
auto *NewShAmt = dyn_cast_or_null<Constant>(
SimplifyAddInst(XShAmt, YShAmt, /*isNSW=*/false,
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index fbff5dd4a8cd..739579e2d38e 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -23,8 +23,11 @@ using namespace PatternMatch;
// Given pattern:
// (x shiftopcode Q) shiftopcode K
// we should rewrite it as
-// x shiftopcode (Q+K) iff (Q+K) u< bitwidth(x)
-// This is valid for any shift, but they must be identical.
+// x shiftopcode (Q+K) iff (Q+K) u< bitwidth(x) and
+//
+// This is valid for any shift, but they must be identical, and we must be
+// careful in case we have (zext(Q)+zext(K)) and look past extensions,
+// (Q+K) must not overflow or else (Q+K) u< bitwidth(x) is bogus.
//
// AnalyzeForSignBitExtraction indicates that we will only analyze whether this
// pattern has any 2 right-shifts that sum to 1 less than original bit width.
@@ -58,6 +61,23 @@ Value *InstCombiner::reassociateShiftAmtsOfTwoSameDirectionShifts(
if (ShAmt0->getType() != ShAmt1->getType())
return nullptr;
+ // As input, we have the following pattern:
+ // Sh0 (Sh1 X, Q), K
+ // We want to rewrite that as:
+ // Sh x, (Q+K) iff (Q+K) u< bitwidth(x)
+ // While we know that originally (Q+K) would not overflow
+ // (because 2 * (N-1) u<= iN -1), we have looked past extensions of
+ // shift amounts. so it may now overflow in smaller bitwidth.
+ // To ensure that does not happen, we need to ensure that the total maximal
+ // shift amount is still representable in that smaller bit width.
+ unsigned MaximalPossibleTotalShiftAmount =
+ (Sh0->getType()->getScalarSizeInBits() - 1) +
+ (Sh1->getType()->getScalarSizeInBits() - 1);
+ APInt MaximalRepresentableShiftAmount =
+ APInt::getAllOnesValue(ShAmt0->getType()->getScalarSizeInBits());
+ if (MaximalRepresentableShiftAmount.ult(MaximalPossibleTotalShiftAmount))
+ return nullptr;
+
// We are only looking for signbit extraction if we have two right shifts.
bool HadTwoRightShifts = match(Sh0, m_Shr(m_Value(), m_Value())) &&
match(Sh1, m_Shr(m_Value(), m_Value()));