summary | refs | log | tree | commit | diff
path: root/lib/Transforms/InstCombine/InstCombineShifts.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineShifts.cpp')
-rw-r--r-- lib/Transforms/InstCombine/InstCombineShifts.cpp 703
1 files changed, 301 insertions, 402 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineShifts.cpp b/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 4ff9b64ac57cf..9aa679c60e47b 100644
--- a/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -22,8 +22,8 @@ using namespace PatternMatch;
#define DEBUG_TYPE "instcombine"
Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {
- assert(I.getOperand(1)->getType() == I.getOperand(0)->getType());
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+ assert(Op0->getType() == Op1->getType());
// See if we can fold away this shift.
if (SimplifyDemandedInstructionBits(I))
@@ -65,63 +65,60 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {
}
/// Return true if we can simplify two logical (either left or right) shifts
-/// that have constant shift amounts.
-static bool canEvaluateShiftedShift(unsigned FirstShiftAmt,
- bool IsFirstShiftLeft,
- Instruction *SecondShift, InstCombiner &IC,
+/// that have constant shift amounts: OuterShift (InnerShift X, C1), C2.
+static bool canEvaluateShiftedShift(unsigned OuterShAmt, bool IsOuterShl,
+ Instruction *InnerShift, InstCombiner &IC,
Instruction *CxtI) {
- assert(SecondShift->isLogicalShift() && "Unexpected instruction type");
+ assert(InnerShift->isLogicalShift() && "Unexpected instruction type");
- // We need constant shifts.
- auto *SecondShiftConst = dyn_cast<ConstantInt>(SecondShift->getOperand(1));
- if (!SecondShiftConst)
+ // We need constant scalar or constant splat shifts.
+ const APInt *InnerShiftConst;
+ if (!match(InnerShift->getOperand(1), m_APInt(InnerShiftConst)))
return false;
- unsigned SecondShiftAmt = SecondShiftConst->getZExtValue();
- bool IsSecondShiftLeft = SecondShift->getOpcode() == Instruction::Shl;
-
- // We can always fold shl(c1) + shl(c2) -> shl(c1+c2).
- // We can always fold lshr(c1) + lshr(c2) -> lshr(c1+c2).
- if (IsFirstShiftLeft == IsSecondShiftLeft)
+ // Two logical shifts in the same direction:
+ // shl (shl X, C1), C2 --> shl X, C1 + C2
+ // lshr (lshr X, C1), C2 --> lshr X, C1 + C2
+ bool IsInnerShl = InnerShift->getOpcode() == Instruction::Shl;
+ if (IsInnerShl == IsOuterShl)
return true;
- // We can always fold lshr(c) + shl(c) -> and(c2).
- // We can always fold shl(c) + lshr(c) -> and(c2).
- if (FirstShiftAmt == SecondShiftAmt)
+ // Equal shift amounts in opposite directions become bitwise 'and':
+ // lshr (shl X, C), C --> and X, C'
+ // shl (lshr X, C), C --> and X, C'
+ unsigned InnerShAmt = InnerShiftConst->getZExtValue();
+ if (InnerShAmt == OuterShAmt)
return true;
- unsigned TypeWidth = SecondShift->getType()->getScalarSizeInBits();
-
// If the 2nd shift is bigger than the 1st, we can fold:
- // lshr(c1) + shl(c2) -> shl(c3) + and(c4) or
- // shl(c1) + lshr(c2) -> lshr(c3) + and(c4),
+ // lshr (shl X, C1), C2 --> and (shl X, C1 - C2), C3
+ // shl (lshr X, C1), C2 --> and (lshr X, C1 - C2), C3
// but it isn't profitable unless we know the and'd out bits are already zero.
- // Also check that the 2nd shift is valid (less than the type width) or we'll
- // crash trying to produce the bit mask for the 'and'.
- if (SecondShiftAmt > FirstShiftAmt && SecondShiftAmt < TypeWidth) {
- unsigned MaskShift = IsSecondShiftLeft ? TypeWidth - SecondShiftAmt
- : SecondShiftAmt - FirstShiftAmt;
- APInt Mask = APInt::getLowBitsSet(TypeWidth, FirstShiftAmt) << MaskShift;
- if (IC.MaskedValueIsZero(SecondShift->getOperand(0), Mask, 0, CxtI))
+ // Also, check that the inner shift is valid (less than the type width) or
+ // we'll crash trying to produce the bit mask for the 'and'.
+ unsigned TypeWidth = InnerShift->getType()->getScalarSizeInBits();
+ if (InnerShAmt > OuterShAmt && InnerShAmt < TypeWidth) {
+ unsigned MaskShift =
+ IsInnerShl ? TypeWidth - InnerShAmt : InnerShAmt - OuterShAmt;
+ APInt Mask = APInt::getLowBitsSet(TypeWidth, OuterShAmt) << MaskShift;
+ if (IC.MaskedValueIsZero(InnerShift->getOperand(0), Mask, 0, CxtI))
return true;
}
return false;
}
-/// See if we can compute the specified value, but shifted
-/// logically to the left or right by some number of bits. This should return
-/// true if the expression can be computed for the same cost as the current
-/// expression tree. This is used to eliminate extraneous shifting from things
-/// like:
+/// See if we can compute the specified value, but shifted logically to the left
+/// or right by some number of bits. This should return true if the expression
+/// can be computed for the same cost as the current expression tree. This is
+/// used to eliminate extraneous shifting from things like:
/// %C = shl i128 %A, 64
/// %D = shl i128 %B, 96
/// %E = or i128 %C, %D
/// %F = lshr i128 %E, 64
-/// where the client will ask if E can be computed shifted right by 64-bits. If
-/// this succeeds, the GetShiftedValue function will be called to produce the
-/// value.
-static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
+/// where the client will ask if E can be computed shifted right by 64-bits. If
+/// this succeeds, getShiftedValue() will be called to produce the value.
+static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
InstCombiner &IC, Instruction *CxtI) {
// We can always evaluate constants shifted.
if (isa<Constant>(V))
@@ -165,8 +162,8 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
case Instruction::Or:
case Instruction::Xor:
// Bitwise operators can all arbitrarily be arbitrarily evaluated shifted.
- return CanEvaluateShifted(I->getOperand(0), NumBits, IsLeftShift, IC, I) &&
- CanEvaluateShifted(I->getOperand(1), NumBits, IsLeftShift, IC, I);
+ return canEvaluateShifted(I->getOperand(0), NumBits, IsLeftShift, IC, I) &&
+ canEvaluateShifted(I->getOperand(1), NumBits, IsLeftShift, IC, I);
case Instruction::Shl:
case Instruction::LShr:
@@ -176,8 +173,8 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
SelectInst *SI = cast<SelectInst>(I);
Value *TrueVal = SI->getTrueValue();
Value *FalseVal = SI->getFalseValue();
- return CanEvaluateShifted(TrueVal, NumBits, IsLeftShift, IC, SI) &&
- CanEvaluateShifted(FalseVal, NumBits, IsLeftShift, IC, SI);
+ return canEvaluateShifted(TrueVal, NumBits, IsLeftShift, IC, SI) &&
+ canEvaluateShifted(FalseVal, NumBits, IsLeftShift, IC, SI);
}
case Instruction::PHI: {
// We can change a phi if we can change all operands. Note that we never
@@ -185,16 +182,79 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
// instructions with a single use.
PHINode *PN = cast<PHINode>(I);
for (Value *IncValue : PN->incoming_values())
- if (!CanEvaluateShifted(IncValue, NumBits, IsLeftShift, IC, PN))
+ if (!canEvaluateShifted(IncValue, NumBits, IsLeftShift, IC, PN))
return false;
return true;
}
}
}
-/// When CanEvaluateShifted returned true for an expression,
-/// this value inserts the new computation that produces the shifted value.
-static Value *GetShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,
+/// Fold OuterShift (InnerShift X, C1), C2.
+/// See canEvaluateShiftedShift() for the constraints on these instructions.
+static Value *foldShiftedShift(BinaryOperator *InnerShift, unsigned OuterShAmt,
+ bool IsOuterShl,
+ InstCombiner::BuilderTy &Builder) {
+ bool IsInnerShl = InnerShift->getOpcode() == Instruction::Shl;
+ Type *ShType = InnerShift->getType();
+ unsigned TypeWidth = ShType->getScalarSizeInBits();
+
+ // We only accept shifts-by-a-constant in canEvaluateShifted().
+ const APInt *C1;
+ match(InnerShift->getOperand(1), m_APInt(C1));
+ unsigned InnerShAmt = C1->getZExtValue();
+
+ // Change the shift amount and clear the appropriate IR flags.
+ auto NewInnerShift = [&](unsigned ShAmt) {
+ InnerShift->setOperand(1, ConstantInt::get(ShType, ShAmt));
+ if (IsInnerShl) {
+ InnerShift->setHasNoUnsignedWrap(false);
+ InnerShift->setHasNoSignedWrap(false);
+ } else {
+ InnerShift->setIsExact(false);
+ }
+ return InnerShift;
+ };
+
+ // Two logical shifts in the same direction:
+ // shl (shl X, C1), C2 --> shl X, C1 + C2
+ // lshr (lshr X, C1), C2 --> lshr X, C1 + C2
+ if (IsInnerShl == IsOuterShl) {
+ // If this is an oversized composite shift, then unsigned shifts get 0.
+ if (InnerShAmt + OuterShAmt >= TypeWidth)
+ return Constant::getNullValue(ShType);
+
+ return NewInnerShift(InnerShAmt + OuterShAmt);
+ }
+
+ // Equal shift amounts in opposite directions become bitwise 'and':
+ // lshr (shl X, C), C --> and X, C'
+ // shl (lshr X, C), C --> and X, C'
+ if (InnerShAmt == OuterShAmt) {
+ APInt Mask = IsInnerShl
+ ? APInt::getLowBitsSet(TypeWidth, TypeWidth - OuterShAmt)
+ : APInt::getHighBitsSet(TypeWidth, TypeWidth - OuterShAmt);
+ Value *And = Builder.CreateAnd(InnerShift->getOperand(0),
+ ConstantInt::get(ShType, Mask));
+ if (auto *AndI = dyn_cast<Instruction>(And)) {
+ AndI->moveBefore(InnerShift);
+ AndI->takeName(InnerShift);
+ }
+ return And;
+ }
+
+ assert(InnerShAmt > OuterShAmt &&
+ "Unexpected opposite direction logical shift pair");
+
+ // In general, we would need an 'and' for this transform, but
+ // canEvaluateShiftedShift() guarantees the masked-off bits are already zero.
+ // lshr (shl X, C1), C2 --> shl X, C1 - C2
+ // shl (lshr X, C1), C2 --> lshr X, C1 - C2
+ return NewInnerShift(InnerShAmt - OuterShAmt);
+}
+
+/// When canEvaluateShifted() returns true for an expression, this function
+/// inserts the new computation that produces the shifted value.
+static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,
InstCombiner &IC, const DataLayout &DL) {
// We can always evaluate constants shifted.
if (Constant *C = dyn_cast<Constant>(V)) {
@@ -220,100 +280,21 @@ static Value *GetShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,
case Instruction::Xor:
// Bitwise operators can all arbitrarily be arbitrarily evaluated shifted.
I->setOperand(
- 0, GetShiftedValue(I->getOperand(0), NumBits, isLeftShift, IC, DL));
+ 0, getShiftedValue(I->getOperand(0), NumBits, isLeftShift, IC, DL));
I->setOperand(
- 1, GetShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL));
+ 1, getShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL));
return I;
- case Instruction::Shl: {
- BinaryOperator *BO = cast<BinaryOperator>(I);
- unsigned TypeWidth = BO->getType()->getScalarSizeInBits();
-
- // We only accept shifts-by-a-constant in CanEvaluateShifted.
- ConstantInt *CI = cast<ConstantInt>(BO->getOperand(1));
-
- // We can always fold shl(c1)+shl(c2) -> shl(c1+c2).
- if (isLeftShift) {
- // If this is oversized composite shift, then unsigned shifts get 0.
- unsigned NewShAmt = NumBits+CI->getZExtValue();
- if (NewShAmt >= TypeWidth)
- return Constant::getNullValue(I->getType());
-
- BO->setOperand(1, ConstantInt::get(BO->getType(), NewShAmt));
- BO->setHasNoUnsignedWrap(false);
- BO->setHasNoSignedWrap(false);
- return I;
- }
-
- // We turn shl(c)+lshr(c) -> and(c2) if the input doesn't already have
- // zeros.
- if (CI->getValue() == NumBits) {
- APInt Mask(APInt::getLowBitsSet(TypeWidth, TypeWidth - NumBits));
- V = IC.Builder->CreateAnd(BO->getOperand(0),
- ConstantInt::get(BO->getContext(), Mask));
- if (Instruction *VI = dyn_cast<Instruction>(V)) {
- VI->moveBefore(BO);
- VI->takeName(BO);
- }
- return V;
- }
-
- // We turn shl(c1)+shr(c2) -> shl(c3)+and(c4), but only when we know that
- // the and won't be needed.
- assert(CI->getZExtValue() > NumBits);
- BO->setOperand(1, ConstantInt::get(BO->getType(),
- CI->getZExtValue() - NumBits));
- BO->setHasNoUnsignedWrap(false);
- BO->setHasNoSignedWrap(false);
- return BO;
- }
- // FIXME: This is almost identical to the SHL case. Refactor both cases into
- // a helper function.
- case Instruction::LShr: {
- BinaryOperator *BO = cast<BinaryOperator>(I);
- unsigned TypeWidth = BO->getType()->getScalarSizeInBits();
- // We only accept shifts-by-a-constant in CanEvaluateShifted.
- ConstantInt *CI = cast<ConstantInt>(BO->getOperand(1));
-
- // We can always fold lshr(c1)+lshr(c2) -> lshr(c1+c2).
- if (!isLeftShift) {
- // If this is oversized composite shift, then unsigned shifts get 0.
- unsigned NewShAmt = NumBits+CI->getZExtValue();
- if (NewShAmt >= TypeWidth)
- return Constant::getNullValue(BO->getType());
-
- BO->setOperand(1, ConstantInt::get(BO->getType(), NewShAmt));
- BO->setIsExact(false);
- return I;
- }
-
- // We turn lshr(c)+shl(c) -> and(c2) if the input doesn't already have
- // zeros.
- if (CI->getValue() == NumBits) {
- APInt Mask(APInt::getHighBitsSet(TypeWidth, TypeWidth - NumBits));
- V = IC.Builder->CreateAnd(I->getOperand(0),
- ConstantInt::get(BO->getContext(), Mask));
- if (Instruction *VI = dyn_cast<Instruction>(V)) {
- VI->moveBefore(I);
- VI->takeName(I);
- }
- return V;
- }
-
- // We turn lshr(c1)+shl(c2) -> lshr(c3)+and(c4), but only when we know that
- // the and won't be needed.
- assert(CI->getZExtValue() > NumBits);
- BO->setOperand(1, ConstantInt::get(BO->getType(),
- CI->getZExtValue() - NumBits));
- BO->setIsExact(false);
- return BO;
- }
+ case Instruction::Shl:
+ case Instruction::LShr:
+ return foldShiftedShift(cast<BinaryOperator>(I), NumBits, isLeftShift,
+ *(IC.Builder));
case Instruction::Select:
I->setOperand(
- 1, GetShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL));
+ 1, getShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL));
I->setOperand(
- 2, GetShiftedValue(I->getOperand(2), NumBits, isLeftShift, IC, DL));
+ 2, getShiftedValue(I->getOperand(2), NumBits, isLeftShift, IC, DL));
return I;
case Instruction::PHI: {
// We can change a phi if we can change all operands. Note that we never
@@ -321,215 +302,39 @@ static Value *GetShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,
// instructions with a single use.
PHINode *PN = cast<PHINode>(I);
for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
- PN->setIncomingValue(i, GetShiftedValue(PN->getIncomingValue(i), NumBits,
+ PN->setIncomingValue(i, getShiftedValue(PN->getIncomingValue(i), NumBits,
isLeftShift, IC, DL));
return PN;
}
}
}
-/// Try to fold (X << C1) << C2, where the shifts are some combination of
-/// shl/ashr/lshr.
-static Instruction *
-foldShiftByConstOfShiftByConst(BinaryOperator &I, ConstantInt *COp1,
- InstCombiner::BuilderTy *Builder) {
- Value *Op0 = I.getOperand(0);
- uint32_t TypeBits = Op0->getType()->getScalarSizeInBits();
-
- // Find out if this is a shift of a shift by a constant.
- BinaryOperator *ShiftOp = dyn_cast<BinaryOperator>(Op0);
- if (ShiftOp && !ShiftOp->isShift())
- ShiftOp = nullptr;
-
- if (ShiftOp && isa<ConstantInt>(ShiftOp->getOperand(1))) {
-
- // This is a constant shift of a constant shift. Be careful about hiding
- // shl instructions behind bit masks. They are used to represent multiplies
- // by a constant, and it is important that simple arithmetic expressions
- // are still recognizable by scalar evolution.
- //
- // The transforms applied to shl are very similar to the transforms applied
- // to mul by constant. We can be more aggressive about optimizing right
- // shifts.
- //
- // Combinations of right and left shifts will still be optimized in
- // DAGCombine where scalar evolution no longer applies.
-
- ConstantInt *ShiftAmt1C = cast<ConstantInt>(ShiftOp->getOperand(1));
- uint32_t ShiftAmt1 = ShiftAmt1C->getLimitedValue(TypeBits);
- uint32_t ShiftAmt2 = COp1->getLimitedValue(TypeBits);
- assert(ShiftAmt2 != 0 && "Should have been simplified earlier");
- if (ShiftAmt1 == 0)
- return nullptr; // Will be simplified in the future.
- Value *X = ShiftOp->getOperand(0);
-
- IntegerType *Ty = cast<IntegerType>(I.getType());
-
- // Check for (X << c1) << c2 and (X >> c1) >> c2
- if (I.getOpcode() == ShiftOp->getOpcode()) {
- uint32_t AmtSum = ShiftAmt1 + ShiftAmt2; // Fold into one big shift.
- // If this is an oversized composite shift, then unsigned shifts become
- // zero (handled in InstSimplify) and ashr saturates.
- if (AmtSum >= TypeBits) {
- if (I.getOpcode() != Instruction::AShr)
- return nullptr;
- AmtSum = TypeBits - 1; // Saturate to 31 for i32 ashr.
- }
-
- return BinaryOperator::Create(I.getOpcode(), X,
- ConstantInt::get(Ty, AmtSum));
- }
-
- if (ShiftAmt1 == ShiftAmt2) {
- // If we have ((X << C) >>u C), turn this into X & (-1 >>u C).
- if (I.getOpcode() == Instruction::LShr &&
- ShiftOp->getOpcode() == Instruction::Shl) {
- APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt1));
- return BinaryOperator::CreateAnd(
- X, ConstantInt::get(I.getContext(), Mask));
- }
- } else if (ShiftAmt1 < ShiftAmt2) {
- uint32_t ShiftDiff = ShiftAmt2 - ShiftAmt1;
-
- // (X >>?,exact C1) << C2 --> X << (C2-C1)
- // The inexact version is deferred to DAGCombine so we don't hide shl
- // behind a bit mask.
- if (I.getOpcode() == Instruction::Shl &&
- ShiftOp->getOpcode() != Instruction::Shl && ShiftOp->isExact()) {
- assert(ShiftOp->getOpcode() == Instruction::LShr ||
- ShiftOp->getOpcode() == Instruction::AShr);
- ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff);
- BinaryOperator *NewShl =
- BinaryOperator::Create(Instruction::Shl, X, ShiftDiffCst);
- NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
- NewShl->setHasNoSignedWrap(I.hasNoSignedWrap());
- return NewShl;
- }
-
- // (X << C1) >>u C2 --> X >>u (C2-C1) & (-1 >> C2)
- if (I.getOpcode() == Instruction::LShr &&
- ShiftOp->getOpcode() == Instruction::Shl) {
- ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff);
- // (X <<nuw C1) >>u C2 --> X >>u (C2-C1)
- if (ShiftOp->hasNoUnsignedWrap()) {
- BinaryOperator *NewLShr =
- BinaryOperator::Create(Instruction::LShr, X, ShiftDiffCst);
- NewLShr->setIsExact(I.isExact());
- return NewLShr;
- }
- Value *Shift = Builder->CreateLShr(X, ShiftDiffCst);
-
- APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt2));
- return BinaryOperator::CreateAnd(
- Shift, ConstantInt::get(I.getContext(), Mask));
- }
-
- // We can't handle (X << C1) >>s C2, it shifts arbitrary bits in. However,
- // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits.
- if (I.getOpcode() == Instruction::AShr &&
- ShiftOp->getOpcode() == Instruction::Shl) {
- if (ShiftOp->hasNoSignedWrap()) {
- // (X <<nsw C1) >>s C2 --> X >>s (C2-C1)
- ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff);
- BinaryOperator *NewAShr =
- BinaryOperator::Create(Instruction::AShr, X, ShiftDiffCst);
- NewAShr->setIsExact(I.isExact());
- return NewAShr;
- }
- }
- } else {
- assert(ShiftAmt2 < ShiftAmt1);
- uint32_t ShiftDiff = ShiftAmt1 - ShiftAmt2;
-
- // (X >>?exact C1) << C2 --> X >>?exact (C1-C2)
- // The inexact version is deferred to DAGCombine so we don't hide shl
- // behind a bit mask.
- if (I.getOpcode() == Instruction::Shl &&
- ShiftOp->getOpcode() != Instruction::Shl && ShiftOp->isExact()) {
- ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff);
- BinaryOperator *NewShr =
- BinaryOperator::Create(ShiftOp->getOpcode(), X, ShiftDiffCst);
- NewShr->setIsExact(true);
- return NewShr;
- }
-
- // (X << C1) >>u C2 --> X << (C1-C2) & (-1 >> C2)
- if (I.getOpcode() == Instruction::LShr &&
- ShiftOp->getOpcode() == Instruction::Shl) {
- ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff);
- if (ShiftOp->hasNoUnsignedWrap()) {
- // (X <<nuw C1) >>u C2 --> X <<nuw (C1-C2)
- BinaryOperator *NewShl =
- BinaryOperator::Create(Instruction::Shl, X, ShiftDiffCst);
- NewShl->setHasNoUnsignedWrap(true);
- return NewShl;
- }
- Value *Shift = Builder->CreateShl(X, ShiftDiffCst);
-
- APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt2));
- return BinaryOperator::CreateAnd(
- Shift, ConstantInt::get(I.getContext(), Mask));
- }
-
- // We can't handle (X << C1) >>s C2, it shifts arbitrary bits in. However,
- // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits.
- if (I.getOpcode() == Instruction::AShr &&
- ShiftOp->getOpcode() == Instruction::Shl) {
- if (ShiftOp->hasNoSignedWrap()) {
- // (X <<nsw C1) >>s C2 --> X <<nsw (C1-C2)
- ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff);
- BinaryOperator *NewShl =
- BinaryOperator::Create(Instruction::Shl, X, ShiftDiffCst);
- NewShl->setHasNoSignedWrap(true);
- return NewShl;
- }
- }
- }
- }
-
- return nullptr;
-}
-
Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
BinaryOperator &I) {
bool isLeftShift = I.getOpcode() == Instruction::Shl;
- ConstantInt *COp1 = nullptr;
- if (ConstantDataVector *CV = dyn_cast<ConstantDataVector>(Op1))
- COp1 = dyn_cast_or_null<ConstantInt>(CV->getSplatValue());
- else if (ConstantVector *CV = dyn_cast<ConstantVector>(Op1))
- COp1 = dyn_cast_or_null<ConstantInt>(CV->getSplatValue());
- else
- COp1 = dyn_cast<ConstantInt>(Op1);
-
- if (!COp1)
+ const APInt *Op1C;
+ if (!match(Op1, m_APInt(Op1C)))
return nullptr;
// See if we can propagate this shift into the input, this covers the trivial
// cast of lshr(shl(x,c1),c2) as well as other more complex cases.
if (I.getOpcode() != Instruction::AShr &&
- CanEvaluateShifted(Op0, COp1->getZExtValue(), isLeftShift, *this, &I)) {
+ canEvaluateShifted(Op0, Op1C->getZExtValue(), isLeftShift, *this, &I)) {
DEBUG(dbgs() << "ICE: GetShiftedValue propagating shift through expression"
" to eliminate shift:\n IN: " << *Op0 << "\n SH: " << I <<"\n");
return replaceInstUsesWith(
- I, GetShiftedValue(Op0, COp1->getZExtValue(), isLeftShift, *this, DL));
+ I, getShiftedValue(Op0, Op1C->getZExtValue(), isLeftShift, *this, DL));
}
// See if we can simplify any instructions used by the instruction whose sole
// purpose is to compute bits we don't care about.
- uint32_t TypeBits = Op0->getType()->getScalarSizeInBits();
+ unsigned TypeBits = Op0->getType()->getScalarSizeInBits();
- assert(!COp1->uge(TypeBits) &&
+ assert(!Op1C->uge(TypeBits) &&
"Shift over the type width should have been removed already");
- // ((X*C1) << C2) == (X * (C1 << C2))
- if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Op0))
- if (BO->getOpcode() == Instruction::Mul && isLeftShift)
- if (Constant *BOOp = dyn_cast<Constant>(BO->getOperand(1)))
- return BinaryOperator::CreateMul(BO->getOperand(0),
- ConstantExpr::getShl(BOOp, Op1));
-
if (Instruction *FoldedShift = foldOpWithConstantIntoOperand(I))
return FoldedShift;
@@ -544,7 +349,8 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
if (TrOp && I.isLogicalShift() && TrOp->isShift() &&
isa<ConstantInt>(TrOp->getOperand(1))) {
// Okay, we'll do this xform. Make the shift of shift.
- Constant *ShAmt = ConstantExpr::getZExt(COp1, TrOp->getType());
+ Constant *ShAmt =
+ ConstantExpr::getZExt(cast<Constant>(Op1), TrOp->getType());
// (shift2 (shift1 & 0x00FF), c2)
Value *NSh = Builder->CreateBinOp(I.getOpcode(), TrOp, ShAmt,I.getName());
@@ -561,10 +367,10 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
// shift. We know that it is a logical shift by a constant, so adjust the
// mask as appropriate.
if (I.getOpcode() == Instruction::Shl)
- MaskV <<= COp1->getZExtValue();
+ MaskV <<= Op1C->getZExtValue();
else {
assert(I.getOpcode() == Instruction::LShr && "Unknown logical shift");
- MaskV = MaskV.lshr(COp1->getZExtValue());
+ MaskV = MaskV.lshr(Op1C->getZExtValue());
}
// shift1 & 0x00FF
@@ -598,7 +404,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
// (X + (Y << C))
Value *X = Builder->CreateBinOp(Op0BO->getOpcode(), YS, V1,
Op0BO->getOperand(1)->getName());
- uint32_t Op1Val = COp1->getLimitedValue(TypeBits);
+ unsigned Op1Val = Op1C->getLimitedValue(TypeBits);
APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val);
Constant *Mask = ConstantInt::get(I.getContext(), Bits);
@@ -634,7 +440,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
// (X + (Y << C))
Value *X = Builder->CreateBinOp(Op0BO->getOpcode(), V1, YS,
Op0BO->getOperand(0)->getName());
- uint32_t Op1Val = COp1->getLimitedValue(TypeBits);
+ unsigned Op1Val = Op1C->getLimitedValue(TypeBits);
APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val);
Constant *Mask = ConstantInt::get(I.getContext(), Bits);
@@ -705,9 +511,6 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
}
}
- if (Instruction *Folded = foldShiftByConstOfShiftByConst(I, COp1, Builder))
- return Folded;
-
return nullptr;
}
@@ -715,59 +518,97 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) {
if (Value *V = SimplifyVectorOp(I))
return replaceInstUsesWith(I, V);
- if (Value *V =
- SimplifyShlInst(I.getOperand(0), I.getOperand(1), I.hasNoSignedWrap(),
- I.hasNoUnsignedWrap(), DL, &TLI, &DT, &AC))
+ Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+ if (Value *V = SimplifyShlInst(Op0, Op1, I.hasNoSignedWrap(),
+ I.hasNoUnsignedWrap(), DL, &TLI, &DT, &AC))
return replaceInstUsesWith(I, V);
if (Instruction *V = commonShiftTransforms(I))
return V;
- if (ConstantInt *Op1C = dyn_cast<ConstantInt>(I.getOperand(1))) {
- unsigned ShAmt = Op1C->getZExtValue();
-
- // Turn:
- // %zext = zext i32 %V to i64
- // %res = shl i64 %V, 8
- //
- // Into:
- // %shl = shl i32 %V, 8
- // %res = zext i32 %shl to i64
- //
- // This is only valid if %V would have zeros shifted out.
- if (auto *ZI = dyn_cast<ZExtInst>(I.getOperand(0))) {
- unsigned SrcBitWidth = ZI->getSrcTy()->getScalarSizeInBits();
- if (ShAmt < SrcBitWidth &&
- MaskedValueIsZero(ZI->getOperand(0),
- APInt::getHighBitsSet(SrcBitWidth, ShAmt), 0, &I)) {
- auto *Shl = Builder->CreateShl(ZI->getOperand(0), ShAmt);
- return new ZExtInst(Shl, I.getType());
+ const APInt *ShAmtAPInt;
+ if (match(Op1, m_APInt(ShAmtAPInt))) {
+ unsigned ShAmt = ShAmtAPInt->getZExtValue();
+ unsigned BitWidth = I.getType()->getScalarSizeInBits();
+ Type *Ty = I.getType();
+
+ // shl (zext X), ShAmt --> zext (shl X, ShAmt)
+ // This is only valid if X would have zeros shifted out.
+ Value *X;
+ if (match(Op0, m_ZExt(m_Value(X)))) {
+ unsigned SrcWidth = X->getType()->getScalarSizeInBits();
+ if (ShAmt < SrcWidth &&
+ MaskedValueIsZero(X, APInt::getHighBitsSet(SrcWidth, ShAmt), 0, &I))
+ return new ZExtInst(Builder->CreateShl(X, ShAmt), Ty);
+ }
+
+ // (X >>u C) << C --> X & (-1 << C)
+ if (match(Op0, m_LShr(m_Value(X), m_Specific(Op1)))) {
+ APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmt));
+ return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask));
+ }
+
+ // Be careful about hiding shl instructions behind bit masks. They are used
+ // to represent multiplies by a constant, and it is important that simple
+ // arithmetic expressions are still recognizable by scalar evolution.
+ // The inexact versions are deferred to DAGCombine, so we don't hide shl
+ // behind a bit mask.
+ const APInt *ShOp1;
+ if (match(Op0, m_CombineOr(m_Exact(m_LShr(m_Value(X), m_APInt(ShOp1))),
+ m_Exact(m_AShr(m_Value(X), m_APInt(ShOp1)))))) {
+ unsigned ShrAmt = ShOp1->getZExtValue();
+ if (ShrAmt < ShAmt) {
+ // If C1 < C2: (X >>?exact C1) << C2 --> X << (C2 - C1)
+ Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShrAmt);
+ auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff);
+ NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
+ NewShl->setHasNoSignedWrap(I.hasNoSignedWrap());
+ return NewShl;
}
+ if (ShrAmt > ShAmt) {
+ // If C1 > C2: (X >>?exact C1) << C2 --> X >>?exact (C1 - C2)
+ Constant *ShiftDiff = ConstantInt::get(Ty, ShrAmt - ShAmt);
+ auto *NewShr = BinaryOperator::Create(
+ cast<BinaryOperator>(Op0)->getOpcode(), X, ShiftDiff);
+ NewShr->setIsExact(true);
+ return NewShr;
+ }
+ }
+
+ if (match(Op0, m_Shl(m_Value(X), m_APInt(ShOp1)))) {
+ unsigned AmtSum = ShAmt + ShOp1->getZExtValue();
+ // Oversized shifts are simplified to zero in InstSimplify.
+ if (AmtSum < BitWidth)
+ // (X << C1) << C2 --> X << (C1 + C2)
+ return BinaryOperator::CreateShl(X, ConstantInt::get(Ty, AmtSum));
}
// If the shifted-out value is known-zero, then this is a NUW shift.
if (!I.hasNoUnsignedWrap() &&
- MaskedValueIsZero(I.getOperand(0),
- APInt::getHighBitsSet(Op1C->getBitWidth(), ShAmt), 0,
- &I)) {
+ MaskedValueIsZero(Op0, APInt::getHighBitsSet(BitWidth, ShAmt), 0, &I)) {
I.setHasNoUnsignedWrap();
return &I;
}
- // If the shifted out value is all signbits, this is a NSW shift.
- if (!I.hasNoSignedWrap() &&
- ComputeNumSignBits(I.getOperand(0), 0, &I) > ShAmt) {
+ // If the shifted-out value is all signbits, then this is a NSW shift.
+ if (!I.hasNoSignedWrap() && ComputeNumSignBits(Op0, 0, &I) > ShAmt) {
I.setHasNoSignedWrap();
return &I;
}
}
- // (C1 << A) << C2 -> (C1 << C2) << A
- Constant *C1, *C2;
- Value *A;
- if (match(I.getOperand(0), m_OneUse(m_Shl(m_Constant(C1), m_Value(A)))) &&
- match(I.getOperand(1), m_Constant(C2)))
- return BinaryOperator::CreateShl(ConstantExpr::getShl(C1, C2), A);
+ Constant *C1;
+ if (match(Op1, m_Constant(C1))) {
+ Constant *C2;
+ Value *X;
+ // (C2 << X) << C1 --> (C2 << C1) << X
+ if (match(Op0, m_OneUse(m_Shl(m_Constant(C2), m_Value(X)))))
+ return BinaryOperator::CreateShl(ConstantExpr::getShl(C2, C1), X);
+
+ // (X * C2) << C1 --> X * (C2 << C1)
+ if (match(Op0, m_Mul(m_Value(X), m_Constant(C2))))
+ return BinaryOperator::CreateMul(X, ConstantExpr::getShl(C2, C1));
+ }
return nullptr;
}
@@ -776,43 +617,83 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) {
if (Value *V = SimplifyVectorOp(I))
return replaceInstUsesWith(I, V);
- if (Value *V = SimplifyLShrInst(I.getOperand(0), I.getOperand(1), I.isExact(),
- DL, &TLI, &DT, &AC))
+ Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+ if (Value *V = SimplifyLShrInst(Op0, Op1, I.isExact(), DL, &TLI, &DT, &AC))
return replaceInstUsesWith(I, V);
if (Instruction *R = commonShiftTransforms(I))
return R;
- Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
-
- if (ConstantInt *Op1C = dyn_cast<ConstantInt>(Op1)) {
- unsigned ShAmt = Op1C->getZExtValue();
-
- if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op0)) {
- unsigned BitWidth = Op0->getType()->getScalarSizeInBits();
+ Type *Ty = I.getType();
+ const APInt *ShAmtAPInt;
+ if (match(Op1, m_APInt(ShAmtAPInt))) {
+ unsigned ShAmt = ShAmtAPInt->getZExtValue();
+ unsigned BitWidth = Ty->getScalarSizeInBits();
+ auto *II = dyn_cast<IntrinsicInst>(Op0);
+ if (II && isPowerOf2_32(BitWidth) && Log2_32(BitWidth) == ShAmt &&
+ (II->getIntrinsicID() == Intrinsic::ctlz ||
+ II->getIntrinsicID() == Intrinsic::cttz ||
+ II->getIntrinsicID() == Intrinsic::ctpop)) {
// ctlz.i32(x)>>5 --> zext(x == 0)
// cttz.i32(x)>>5 --> zext(x == 0)
// ctpop.i32(x)>>5 --> zext(x == -1)
- if ((II->getIntrinsicID() == Intrinsic::ctlz ||
- II->getIntrinsicID() == Intrinsic::cttz ||
- II->getIntrinsicID() == Intrinsic::ctpop) &&
- isPowerOf2_32(BitWidth) && Log2_32(BitWidth) == ShAmt) {
- bool isCtPop = II->getIntrinsicID() == Intrinsic::ctpop;
- Constant *RHS = ConstantInt::getSigned(Op0->getType(), isCtPop ? -1:0);
- Value *Cmp = Builder->CreateICmpEQ(II->getArgOperand(0), RHS);
- return new ZExtInst(Cmp, II->getType());
+ bool IsPop = II->getIntrinsicID() == Intrinsic::ctpop;
+ Constant *RHS = ConstantInt::getSigned(Ty, IsPop ? -1 : 0);
+ Value *Cmp = Builder->CreateICmpEQ(II->getArgOperand(0), RHS);
+ return new ZExtInst(Cmp, Ty);
+ }
+
+ Value *X;
+ const APInt *ShOp1;
+ if (match(Op0, m_Shl(m_Value(X), m_APInt(ShOp1)))) {
+ unsigned ShlAmt = ShOp1->getZExtValue();
+ if (ShlAmt < ShAmt) {
+ Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShlAmt);
+ if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) {
+ // (X <<nuw C1) >>u C2 --> X >>u (C2 - C1)
+ auto *NewLShr = BinaryOperator::CreateLShr(X, ShiftDiff);
+ NewLShr->setIsExact(I.isExact());
+ return NewLShr;
+ }
+ // (X << C1) >>u C2 --> (X >>u (C2 - C1)) & (-1 >>u C2)
+ Value *NewLShr = Builder->CreateLShr(X, ShiftDiff, "", I.isExact());
+ APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt));
+ return BinaryOperator::CreateAnd(NewLShr, ConstantInt::get(Ty, Mask));
}
+ if (ShlAmt > ShAmt) {
+ Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmt - ShAmt);
+ if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) {
+ // (X <<nuw C1) >>u C2 --> X <<nuw (C1 - C2)
+ auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff);
+ NewShl->setHasNoUnsignedWrap(true);
+ return NewShl;
+ }
+ // (X << C1) >>u C2 --> (X << (C1 - C2)) & (-1 >>u C2)
+ Value *NewShl = Builder->CreateShl(X, ShiftDiff);
+ APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt));
+ return BinaryOperator::CreateAnd(NewShl, ConstantInt::get(Ty, Mask));
+ }
+ assert(ShlAmt == ShAmt);
+ // (X << C) >>u C --> X & (-1 >>u C)
+ APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt));
+ return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask));
+ }
+
+ if (match(Op0, m_LShr(m_Value(X), m_APInt(ShOp1)))) {
+ unsigned AmtSum = ShAmt + ShOp1->getZExtValue();
+ // Oversized shifts are simplified to zero in InstSimplify.
+ if (AmtSum < BitWidth)
+ // (X >>u C1) >>u C2 --> X >>u (C1 + C2)
+ return BinaryOperator::CreateLShr(X, ConstantInt::get(Ty, AmtSum));
}
// If the shifted-out value is known-zero, then this is an exact shift.
if (!I.isExact() &&
- MaskedValueIsZero(Op0, APInt::getLowBitsSet(Op1C->getBitWidth(), ShAmt),
- 0, &I)){
+ MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) {
I.setIsExact();
return &I;
}
}
-
return nullptr;
}
@@ -820,48 +701,66 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) {
if (Value *V = SimplifyVectorOp(I))
return replaceInstUsesWith(I, V);
- if (Value *V = SimplifyAShrInst(I.getOperand(0), I.getOperand(1), I.isExact(),
- DL, &TLI, &DT, &AC))
+ Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+ if (Value *V = SimplifyAShrInst(Op0, Op1, I.isExact(), DL, &TLI, &DT, &AC))
return replaceInstUsesWith(I, V);
if (Instruction *R = commonShiftTransforms(I))
return R;
- Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
-
- if (ConstantInt *Op1C = dyn_cast<ConstantInt>(Op1)) {
- unsigned ShAmt = Op1C->getZExtValue();
+ Type *Ty = I.getType();
+ unsigned BitWidth = Ty->getScalarSizeInBits();
+ const APInt *ShAmtAPInt;
+ if (match(Op1, m_APInt(ShAmtAPInt))) {
+ unsigned ShAmt = ShAmtAPInt->getZExtValue();
- // If the input is a SHL by the same constant (ashr (shl X, C), C), then we
- // have a sign-extend idiom.
+ // If the shift amount equals the difference in width of the destination
+ // and source scalar types:
+ // ashr (shl (zext X), C), C --> sext X
Value *X;
- if (match(Op0, m_Shl(m_Value(X), m_Specific(Op1)))) {
- // If the input is an extension from the shifted amount value, e.g.
- // %x = zext i8 %A to i32
- // %y = shl i32 %x, 24
- // %z = ashr %y, 24
- // then turn this into "z = sext i8 A to i32".
- if (ZExtInst *ZI = dyn_cast<ZExtInst>(X)) {
- uint32_t SrcBits = ZI->getOperand(0)->getType()->getScalarSizeInBits();
- uint32_t DestBits = ZI->getType()->getScalarSizeInBits();
- if (Op1C->getZExtValue() == DestBits-SrcBits)
- return new SExtInst(ZI->getOperand(0), ZI->getType());
+ if (match(Op0, m_Shl(m_ZExt(m_Value(X)), m_Specific(Op1))) &&
+ ShAmt == BitWidth - X->getType()->getScalarSizeInBits())
+ return new SExtInst(X, Ty);
+
+ // We can't handle (X << C1) >>s C2. It shifts arbitrary bits in. However,
+ // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits.
+ const APInt *ShOp1;
+ if (match(Op0, m_NSWShl(m_Value(X), m_APInt(ShOp1)))) {
+ unsigned ShlAmt = ShOp1->getZExtValue();
+ if (ShlAmt < ShAmt) {
+ // (X <<nsw C1) >>s C2 --> X >>s (C2 - C1)
+ Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShlAmt);
+ auto *NewAShr = BinaryOperator::CreateAShr(X, ShiftDiff);
+ NewAShr->setIsExact(I.isExact());
+ return NewAShr;
}
+ if (ShlAmt > ShAmt) {
+ // (X <<nsw C1) >>s C2 --> X <<nsw (C1 - C2)
+ Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmt - ShAmt);
+ auto *NewShl = BinaryOperator::Create(Instruction::Shl, X, ShiftDiff);
+ NewShl->setHasNoSignedWrap(true);
+ return NewShl;
+ }
+ }
+
+ if (match(Op0, m_AShr(m_Value(X), m_APInt(ShOp1)))) {
+ unsigned AmtSum = ShAmt + ShOp1->getZExtValue();
+ // Oversized arithmetic shifts replicate the sign bit.
+ AmtSum = std::min(AmtSum, BitWidth - 1);
+ // (X >>s C1) >>s C2 --> X >>s (C1 + C2)
+ return BinaryOperator::CreateAShr(X, ConstantInt::get(Ty, AmtSum));
}
// If the shifted-out value is known-zero, then this is an exact shift.
if (!I.isExact() &&
- MaskedValueIsZero(Op0, APInt::getLowBitsSet(Op1C->getBitWidth(), ShAmt),
- 0, &I)) {
+ MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) {
I.setIsExact();
return &I;
}
}
// See if we can turn a signed shr into an unsigned shr.
- if (MaskedValueIsZero(Op0,
- APInt::getSignBit(I.getType()->getScalarSizeInBits()),
- 0, &I))
+ if (MaskedValueIsZero(Op0, APInt::getSignBit(BitWidth), 0, &I))
return BinaryOperator::CreateLShr(Op0, Op1);
return nullptr;