about summary refs log tree commit diff
path: root/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp')
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp | 157
1 files changed, 136 insertions, 21 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 13c98b935adf..ec505381cc86 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -346,8 +346,8 @@ static Instruction *foldShiftOfShiftedLogic(BinaryOperator &I,
Value *X, *Y;
auto matchFirstShift = [&](Value *V) {
APInt Threshold(Ty->getScalarSizeInBits(), Ty->getScalarSizeInBits());
- return match(V, m_BinOp(ShiftOpcode, m_Value(), m_Value())) &&
- match(V, m_OneUse(m_Shift(m_Value(X), m_Constant(C0)))) &&
+ return match(V,
+ m_OneUse(m_BinOp(ShiftOpcode, m_Value(X), m_Constant(C0)))) &&
match(ConstantExpr::getAdd(C0, C1),
m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, Threshold));
};
@@ -363,7 +363,7 @@ static Instruction *foldShiftOfShiftedLogic(BinaryOperator &I,
// shift (logic (shift X, C0), Y), C1 -> logic (shift X, C0+C1), (shift Y, C1)
Constant *ShiftSumC = ConstantExpr::getAdd(C0, C1);
Value *NewShift1 = Builder.CreateBinOp(ShiftOpcode, X, ShiftSumC);
- Value *NewShift2 = Builder.CreateBinOp(ShiftOpcode, Y, I.getOperand(1));
+ Value *NewShift2 = Builder.CreateBinOp(ShiftOpcode, Y, C1);
return BinaryOperator::Create(LogicInst->getOpcode(), NewShift1, NewShift2);
}
@@ -730,13 +730,34 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *C1,
return BinaryOperator::Create(
I.getOpcode(), Builder.CreateBinOp(I.getOpcode(), C2, C1), X);
+ bool IsLeftShift = I.getOpcode() == Instruction::Shl;
+ Type *Ty = I.getType();
+ unsigned TypeBits = Ty->getScalarSizeInBits();
+
+ // (X / +DivC) >> (Width - 1) --> ext (X <= -DivC)
+ // (X / -DivC) >> (Width - 1) --> ext (X >= +DivC)
+ const APInt *DivC;
+ if (!IsLeftShift && match(C1, m_SpecificIntAllowUndef(TypeBits - 1)) &&
+ match(Op0, m_SDiv(m_Value(X), m_APInt(DivC))) && !DivC->isZero() &&
+ !DivC->isMinSignedValue()) {
+ Constant *NegDivC = ConstantInt::get(Ty, -(*DivC));
+ ICmpInst::Predicate Pred =
+ DivC->isNegative() ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_SLE;
+ Value *Cmp = Builder.CreateICmp(Pred, X, NegDivC);
+ auto ExtOpcode = (I.getOpcode() == Instruction::AShr) ? Instruction::SExt
+ : Instruction::ZExt;
+ return CastInst::Create(ExtOpcode, Cmp, Ty);
+ }
+
const APInt *Op1C;
if (!match(C1, m_APInt(Op1C)))
return nullptr;
+ assert(!Op1C->uge(TypeBits) &&
+ "Shift over the type width should have been removed already");
+
// See if we can propagate this shift into the input, this covers the trivial
// cast of lshr(shl(x,c1),c2) as well as other more complex cases.
- bool IsLeftShift = I.getOpcode() == Instruction::Shl;
if (I.getOpcode() != Instruction::AShr &&
canEvaluateShifted(Op0, Op1C->getZExtValue(), IsLeftShift, *this, &I)) {
LLVM_DEBUG(
@@ -748,14 +769,6 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *C1,
I, getShiftedValue(Op0, Op1C->getZExtValue(), IsLeftShift, *this, DL));
}
- // See if we can simplify any instructions used by the instruction whose sole
- // purpose is to compute bits we don't care about.
- Type *Ty = I.getType();
- unsigned TypeBits = Ty->getScalarSizeInBits();
- assert(!Op1C->uge(TypeBits) &&
- "Shift over the type width should have been removed already");
- (void)TypeBits;
-
if (Instruction *FoldedShift = foldBinOpIntoSelectOrPhi(I))
return FoldedShift;
@@ -826,6 +839,74 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *C1,
return nullptr;
}
+// Tries to perform
+// (lshr (add (zext X), (zext Y)), K)
+// -> (icmp ult (add X, Y), X)
+// where
+// - The add's operands are zexts from a K-bits integer to a bigger type.
+// - The add is only used by the shr, or by iK (or narrower) truncates.
+// - The lshr type has more than 2 bits (other types are boolean math).
+// - K > 1
+// note that
+// - The resulting add cannot have nuw/nsw, else on overflow we get a
+// poison value and the transform isn't legal anymore.
+Instruction *InstCombinerImpl::foldLShrOverflowBit(BinaryOperator &I) {
+ assert(I.getOpcode() == Instruction::LShr);
+
+ Value *Add = I.getOperand(0);
+ Value *ShiftAmt = I.getOperand(1);
+ Type *Ty = I.getType();
+
+ if (Ty->getScalarSizeInBits() < 3)
+ return nullptr;
+
+ const APInt *ShAmtAPInt = nullptr;
+ Value *X = nullptr, *Y = nullptr;
+ if (!match(ShiftAmt, m_APInt(ShAmtAPInt)) ||
+ !match(Add,
+ m_Add(m_OneUse(m_ZExt(m_Value(X))), m_OneUse(m_ZExt(m_Value(Y))))))
+ return nullptr;
+
+ const unsigned ShAmt = ShAmtAPInt->getZExtValue();
+ if (ShAmt == 1)
+ return nullptr;
+
+ // X/Y are zexts from `ShAmt`-sized ints.
+ if (X->getType()->getScalarSizeInBits() != ShAmt ||
+ Y->getType()->getScalarSizeInBits() != ShAmt)
+ return nullptr;
+
+ // Make sure that `Add` is only used by `I` and `ShAmt`-truncates.
+ if (!Add->hasOneUse()) {
+ for (User *U : Add->users()) {
+ if (U == &I)
+ continue;
+
+ TruncInst *Trunc = dyn_cast<TruncInst>(U);
+ if (!Trunc || Trunc->getType()->getScalarSizeInBits() > ShAmt)
+ return nullptr;
+ }
+ }
+
+ // Insert at Add so that the newly created `NarrowAdd` will dominate its
+ // users (i.e. `Add`'s users).
+ Instruction *AddInst = cast<Instruction>(Add);
+ Builder.SetInsertPoint(AddInst);
+
+ Value *NarrowAdd = Builder.CreateAdd(X, Y, "add.narrowed");
+ Value *Overflow =
+ Builder.CreateICmpULT(NarrowAdd, X, "add.narrowed.overflow");
+
+ // Replace the uses of the original add with a zext of the
+ // NarrowAdd's result. Note that all users at this stage are known to
+ // be ShAmt-sized truncs, or the lshr itself.
+ if (!Add->hasOneUse())
+ replaceInstUsesWith(*AddInst, Builder.CreateZExt(NarrowAdd, Ty));
+
+ // Replace the LShr with a zext of the overflow check.
+ return new ZExtInst(Overflow, Ty);
+}
+
Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
const SimplifyQuery Q = SQ.getWithInstruction(&I);
@@ -1046,11 +1127,21 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
}
}
- // (1 << (C - x)) -> ((1 << C) >> x) if C is bitwidth - 1
- if (match(Op0, m_One()) &&
- match(Op1, m_Sub(m_SpecificInt(BitWidth - 1), m_Value(X))))
- return BinaryOperator::CreateLShr(
- ConstantInt::get(Ty, APInt::getSignMask(BitWidth)), X);
+ if (match(Op0, m_One())) {
+ // (1 << (C - x)) -> ((1 << C) >> x) if C is bitwidth - 1
+ if (match(Op1, m_Sub(m_SpecificInt(BitWidth - 1), m_Value(X))))
+ return BinaryOperator::CreateLShr(
+ ConstantInt::get(Ty, APInt::getSignMask(BitWidth)), X);
+
+ // The only way to shift out the 1 is with an over-shift, so that would
+ // be poison with or without "nuw". Undef is excluded because (undef << X)
+ // is not undef (it is zero).
+ Constant *ConstantOne = cast<Constant>(Op0);
+ if (!I.hasNoUnsignedWrap() && !ConstantOne->containsUndefElement()) {
+ I.setHasNoUnsignedWrap();
+ return &I;
+ }
+ }
return nullptr;
}
@@ -1068,10 +1159,17 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
Type *Ty = I.getType();
+ Value *X;
const APInt *C;
+ unsigned BitWidth = Ty->getScalarSizeInBits();
+
+ // (iN (~X) u>> (N - 1)) --> zext (X > -1)
+ if (match(Op0, m_OneUse(m_Not(m_Value(X)))) &&
+ match(Op1, m_SpecificIntAllowUndef(BitWidth - 1)))
+ return new ZExtInst(Builder.CreateIsNotNeg(X, "isnotneg"), Ty);
+
if (match(Op1, m_APInt(C))) {
unsigned ShAmtC = C->getZExtValue();
- unsigned BitWidth = Ty->getScalarSizeInBits();
auto *II = dyn_cast<IntrinsicInst>(Op0);
if (II && isPowerOf2_32(BitWidth) && Log2_32(BitWidth) == ShAmtC &&
(II->getIntrinsicID() == Intrinsic::ctlz ||
@@ -1276,6 +1374,18 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
}
}
+ // Reduce add-carry of bools to logic:
+ // ((zext BoolX) + (zext BoolY)) >> 1 --> zext (BoolX && BoolY)
+ Value *BoolX, *BoolY;
+ if (ShAmtC == 1 && match(Op0, m_Add(m_Value(X), m_Value(Y))) &&
+ match(X, m_ZExt(m_Value(BoolX))) && match(Y, m_ZExt(m_Value(BoolY))) &&
+ BoolX->getType()->isIntOrIntVectorTy(1) &&
+ BoolY->getType()->isIntOrIntVectorTy(1) &&
+ (X->hasOneUse() || Y->hasOneUse() || Op0->hasOneUse())) {
+ Value *And = Builder.CreateAnd(BoolX, BoolY);
+ return new ZExtInst(And, Ty);
+ }
+
// If the shifted-out value is known-zero, then this is an exact shift.
if (!I.isExact() &&
MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmtC), 0, &I)) {
@@ -1285,13 +1395,15 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
}
// Transform (x << y) >> y to x & (-1 >> y)
- Value *X;
if (match(Op0, m_OneUse(m_Shl(m_Value(X), m_Specific(Op1))))) {
Constant *AllOnes = ConstantInt::getAllOnesValue(Ty);
Value *Mask = Builder.CreateLShr(AllOnes, Op1);
return BinaryOperator::CreateAnd(Mask, X);
}
+ if (Instruction *Overflow = foldLShrOverflowBit(I))
+ return Overflow;
+
return nullptr;
}
@@ -1469,8 +1581,11 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) {
return R;
// See if we can turn a signed shr into an unsigned shr.
- if (MaskedValueIsZero(Op0, APInt::getSignMask(BitWidth), 0, &I))
- return BinaryOperator::CreateLShr(Op0, Op1);
+ if (MaskedValueIsZero(Op0, APInt::getSignMask(BitWidth), 0, &I)) {
+ Instruction *Lshr = BinaryOperator::CreateLShr(Op0, Op1);
+ Lshr->setIsExact(I.isExact());
+ return Lshr;
+ }
// ashr (xor %x, -1), %y --> xor (ashr %x, %y), -1
if (match(Op0, m_OneUse(m_Not(m_Value(X))))) {