summaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp')
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp163
1 file changed, 108 insertions, 55 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 89dad455f015..b7958978c450 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -136,9 +136,14 @@ Value *InstCombinerImpl::reassociateShiftAmtsOfTwoSameDirectionShifts(
assert(IdenticalShOpcodes && "Should not get here with different shifts.");
- // All good, we can do this fold.
- NewShAmt = ConstantExpr::getZExtOrBitCast(NewShAmt, X->getType());
+ if (NewShAmt->getType() != X->getType()) {
+ NewShAmt = ConstantFoldCastOperand(Instruction::ZExt, NewShAmt,
+ X->getType(), SQ.DL);
+ if (!NewShAmt)
+ return nullptr;
+ }
+ // All good, we can do this fold.
BinaryOperator *NewShift = BinaryOperator::Create(ShiftOpcode, X, NewShAmt);
// The flags can only be propagated if there wasn't a trunc.
@@ -245,7 +250,11 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift,
SumOfShAmts = Constant::replaceUndefsWith(
SumOfShAmts, ConstantInt::get(SumOfShAmts->getType()->getScalarType(),
ExtendedTy->getScalarSizeInBits()));
- auto *ExtendedSumOfShAmts = ConstantExpr::getZExt(SumOfShAmts, ExtendedTy);
+ auto *ExtendedSumOfShAmts = ConstantFoldCastOperand(
+ Instruction::ZExt, SumOfShAmts, ExtendedTy, Q.DL);
+ if (!ExtendedSumOfShAmts)
+ return nullptr;
+
// And compute the mask as usual: ~(-1 << (SumOfShAmts))
auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy);
auto *ExtendedInvertedMask =
@@ -278,16 +287,22 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift,
ShAmtsDiff = Constant::replaceUndefsWith(
ShAmtsDiff, ConstantInt::get(ShAmtsDiff->getType()->getScalarType(),
-WidestTyBitWidth));
- auto *ExtendedNumHighBitsToClear = ConstantExpr::getZExt(
+ auto *ExtendedNumHighBitsToClear = ConstantFoldCastOperand(
+ Instruction::ZExt,
ConstantExpr::getSub(ConstantInt::get(ShAmtsDiff->getType(),
WidestTyBitWidth,
/*isSigned=*/false),
ShAmtsDiff),
- ExtendedTy);
+ ExtendedTy, Q.DL);
+ if (!ExtendedNumHighBitsToClear)
+ return nullptr;
+
// And compute the mask as usual: (-1 l>> (NumHighBitsToClear))
auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy);
- NewMask =
- ConstantExpr::getLShr(ExtendedAllOnes, ExtendedNumHighBitsToClear);
+ NewMask = ConstantFoldBinaryOpOperands(Instruction::LShr, ExtendedAllOnes,
+ ExtendedNumHighBitsToClear, Q.DL);
+ if (!NewMask)
+ return nullptr;
} else
return nullptr; // Don't know anything about this pattern.
@@ -545,8 +560,8 @@ static bool canEvaluateShiftedShift(unsigned OuterShAmt, bool IsOuterShl,
/// this succeeds, getShiftedValue() will be called to produce the value.
static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
InstCombinerImpl &IC, Instruction *CxtI) {
- // We can always evaluate constants shifted.
- if (isa<Constant>(V))
+ // We can always evaluate immediate constants.
+ if (match(V, m_ImmConstant()))
return true;
Instruction *I = dyn_cast<Instruction>(V);
@@ -709,13 +724,13 @@ static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,
case Instruction::Mul: {
assert(!isLeftShift && "Unexpected shift direction!");
auto *Neg = BinaryOperator::CreateNeg(I->getOperand(0));
- IC.InsertNewInstWith(Neg, *I);
+ IC.InsertNewInstWith(Neg, I->getIterator());
unsigned TypeWidth = I->getType()->getScalarSizeInBits();
APInt Mask = APInt::getLowBitsSet(TypeWidth, TypeWidth - NumBits);
auto *And = BinaryOperator::CreateAnd(Neg,
ConstantInt::get(I->getType(), Mask));
And->takeName(I);
- return IC.InsertNewInstWith(And, *I);
+ return IC.InsertNewInstWith(And, I->getIterator());
}
}
}
@@ -745,7 +760,7 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *C1,
// (C2 >> X) >> C1 --> (C2 >> C1) >> X
Constant *C2;
Value *X;
- if (match(Op0, m_BinOp(I.getOpcode(), m_Constant(C2), m_Value(X))))
+ if (match(Op0, m_BinOp(I.getOpcode(), m_ImmConstant(C2), m_Value(X))))
return BinaryOperator::Create(
I.getOpcode(), Builder.CreateBinOp(I.getOpcode(), C2, C1), X);
@@ -928,6 +943,60 @@ Instruction *InstCombinerImpl::foldLShrOverflowBit(BinaryOperator &I) {
return new ZExtInst(Overflow, Ty);
}
+// Try to set nuw/nsw flags on shl or exact flag on lshr/ashr using knownbits.
+static bool setShiftFlags(BinaryOperator &I, const SimplifyQuery &Q) {
+ assert(I.isShift() && "Expected a shift as input");
+ // We already have all the flags.
+ if (I.getOpcode() == Instruction::Shl) {
+ if (I.hasNoUnsignedWrap() && I.hasNoSignedWrap())
+ return false;
+ } else {
+ if (I.isExact())
+ return false;
+
+ // shr (shl X, Y), Y
+ if (match(I.getOperand(0), m_Shl(m_Value(), m_Specific(I.getOperand(1))))) {
+ I.setIsExact();
+ return true;
+ }
+ }
+
+ // Compute what we know about shift count.
+ KnownBits KnownCnt = computeKnownBits(I.getOperand(1), /* Depth */ 0, Q);
+ unsigned BitWidth = KnownCnt.getBitWidth();
+ // Since shift produces a poison value if RHS is equal to or larger than the
+ // bit width, we can safely assume that RHS is less than the bit width.
+ uint64_t MaxCnt = KnownCnt.getMaxValue().getLimitedValue(BitWidth - 1);
+
+ KnownBits KnownAmt = computeKnownBits(I.getOperand(0), /* Depth */ 0, Q);
+ bool Changed = false;
+
+ if (I.getOpcode() == Instruction::Shl) {
+ // If we have at least as many leading zeros as the maximum shift count, we have nuw.
+ if (!I.hasNoUnsignedWrap() && MaxCnt <= KnownAmt.countMinLeadingZeros()) {
+ I.setHasNoUnsignedWrap();
+ Changed = true;
+ }
+ // If we have more sign bits than the maximum shift count, we have nsw.
+ if (!I.hasNoSignedWrap()) {
+ if (MaxCnt < KnownAmt.countMinSignBits() ||
+ MaxCnt < ComputeNumSignBits(I.getOperand(0), Q.DL, /*Depth*/ 0, Q.AC,
+ Q.CxtI, Q.DT)) {
+ I.setHasNoSignedWrap();
+ Changed = true;
+ }
+ }
+ return Changed;
+ }
+
+ // If we have at least as many trailing zeros as the maximum shift count,
+ // then the shift is exact.
+ Changed = MaxCnt <= KnownAmt.countMinTrailingZeros();
+ I.setIsExact(Changed);
+
+ return Changed;
+}
+
Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
const SimplifyQuery Q = SQ.getWithInstruction(&I);
@@ -976,7 +1045,11 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
// If C1 < C: (X >>?,exact C1) << C --> X << (C - C1)
Constant *ShiftDiff = ConstantInt::get(Ty, ShAmtC - ShrAmt);
auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff);
- NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
+ NewShl->setHasNoUnsignedWrap(
+ I.hasNoUnsignedWrap() ||
+ (ShrAmt &&
+ cast<Instruction>(Op0)->getOpcode() == Instruction::LShr &&
+ I.hasNoSignedWrap()));
NewShl->setHasNoSignedWrap(I.hasNoSignedWrap());
return NewShl;
}
@@ -997,7 +1070,11 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
// If C1 < C: (X >>? C1) << C --> (X << (C - C1)) & (-1 << C)
Constant *ShiftDiff = ConstantInt::get(Ty, ShAmtC - ShrAmt);
auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff);
- NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
+ NewShl->setHasNoUnsignedWrap(
+ I.hasNoUnsignedWrap() ||
+ (ShrAmt &&
+ cast<Instruction>(Op0)->getOpcode() == Instruction::LShr &&
+ I.hasNoSignedWrap()));
NewShl->setHasNoSignedWrap(I.hasNoSignedWrap());
Builder.Insert(NewShl);
APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmtC));
@@ -1108,22 +1185,11 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
Value *NewShift = Builder.CreateShl(X, Op1);
return BinaryOperator::CreateSub(NewLHS, NewShift);
}
-
- // If the shifted-out value is known-zero, then this is a NUW shift.
- if (!I.hasNoUnsignedWrap() &&
- MaskedValueIsZero(Op0, APInt::getHighBitsSet(BitWidth, ShAmtC), 0,
- &I)) {
- I.setHasNoUnsignedWrap();
- return &I;
- }
-
- // If the shifted-out value is all signbits, then this is a NSW shift.
- if (!I.hasNoSignedWrap() && ComputeNumSignBits(Op0, 0, &I) > ShAmtC) {
- I.setHasNoSignedWrap();
- return &I;
- }
}
+ if (setShiftFlags(I, Q))
+ return &I;
+
// Transform (x >> y) << y to x & (-1 << y)
// Valid for any type of right-shift.
Value *X;
@@ -1161,15 +1227,6 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
Value *NegX = Builder.CreateNeg(X, "neg");
return BinaryOperator::CreateAnd(NegX, X);
}
-
- // The only way to shift out the 1 is with an over-shift, so that would
- // be poison with or without "nuw". Undef is excluded because (undef << X)
- // is not undef (it is zero).
- Constant *ConstantOne = cast<Constant>(Op0);
- if (!I.hasNoUnsignedWrap() && !ConstantOne->containsUndefElement()) {
- I.setHasNoUnsignedWrap();
- return &I;
- }
}
return nullptr;
@@ -1235,9 +1292,10 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
unsigned ShlAmtC = C1->getZExtValue();
Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmtC - ShAmtC);
if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) {
- // (X <<nuw C1) >>u C --> X <<nuw (C1 - C)
+ // (X <<nuw C1) >>u C --> X <<nuw/nsw (C1 - C)
auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff);
NewShl->setHasNoUnsignedWrap(true);
+ NewShl->setHasNoSignedWrap(ShAmtC > 0);
return NewShl;
}
if (Op0->hasOneUse()) {
@@ -1370,12 +1428,13 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
if (Op0->hasOneUse()) {
APInt NewMulC = MulC->lshr(ShAmtC);
// if c is divisible by (1 << ShAmtC):
- // lshr (mul nuw x, MulC), ShAmtC -> mul nuw x, (MulC >> ShAmtC)
+ // lshr (mul nuw x, MulC), ShAmtC -> mul nuw nsw x, (MulC >> ShAmtC)
if (MulC->eq(NewMulC.shl(ShAmtC))) {
auto *NewMul =
BinaryOperator::CreateNUWMul(X, ConstantInt::get(Ty, NewMulC));
- BinaryOperator *OrigMul = cast<BinaryOperator>(Op0);
- NewMul->setHasNoSignedWrap(OrigMul->hasNoSignedWrap());
+ assert(ShAmtC != 0 &&
+ "lshr X, 0 should be handled by simplifyLShrInst.");
+ NewMul->setHasNoSignedWrap(true);
return NewMul;
}
}
@@ -1414,15 +1473,12 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
Value *And = Builder.CreateAnd(BoolX, BoolY);
return new ZExtInst(And, Ty);
}
-
- // If the shifted-out value is known-zero, then this is an exact shift.
- if (!I.isExact() &&
- MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmtC), 0, &I)) {
- I.setIsExact();
- return &I;
- }
}
+ const SimplifyQuery Q = SQ.getWithInstruction(&I);
+ if (setShiftFlags(I, Q))
+ return &I;
+
// Transform (x << y) >> y to x & (-1 >> y)
if (match(Op0, m_OneUse(m_Shl(m_Value(X), m_Specific(Op1))))) {
Constant *AllOnes = ConstantInt::getAllOnesValue(Ty);
@@ -1581,15 +1637,12 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) {
if (match(Op0, m_OneUse(m_NSWSub(m_Value(X), m_Value(Y)))))
return new SExtInst(Builder.CreateICmpSLT(X, Y), Ty);
}
-
- // If the shifted-out value is known-zero, then this is an exact shift.
- if (!I.isExact() &&
- MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) {
- I.setIsExact();
- return &I;
- }
}
+ const SimplifyQuery Q = SQ.getWithInstruction(&I);
+ if (setShiftFlags(I, Q))
+ return &I;
+
// Prefer `-(x & 1)` over `(x << (bitwidth(x)-1)) a>> (bitwidth(x)-1)`
// as the pattern to splat the lowest bit.
// FIXME: iff X is already masked, we don't need the one-use check.