Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineShifts.cpp')

 lib/Transforms/InstCombine/InstCombineShifts.cpp | 85 ++++++++++++++--------
 1 file changed, 52 insertions(+), 33 deletions(-)
diff --git a/lib/Transforms/InstCombine/InstCombineShifts.cpp b/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 44bbb84686ab..34f8037e519f 100644
--- a/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -87,8 +87,7 @@ static bool canEvaluateShiftedShift(unsigned OuterShAmt, bool IsOuterShl,
// Equal shift amounts in opposite directions become bitwise 'and':
// lshr (shl X, C), C --> and X, C'
// shl (lshr X, C), C --> and X, C'
- unsigned InnerShAmt = InnerShiftConst->getZExtValue();
- if (InnerShAmt == OuterShAmt)
+ if (*InnerShiftConst == OuterShAmt)
return true;
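
The identity behind this fold is easy to spot-check outside the compiler. The sketch below is illustrative only (plain C++, not LLVM code) and exhaustively verifies both directions on i8:

    // lshr (shl X, C), C  ==  X & (~0 >> C)   (clears the high C bits)
    // shl (lshr X, C), C  ==  X & (~0 << C)   (clears the low C bits)
    #include <cassert>
    #include <cstdint>

    int main() {
      for (unsigned C = 0; C < 8; ++C)
        for (unsigned V = 0; V < 256; ++V) {
          uint8_t X = (uint8_t)V;
          assert((uint8_t)((uint8_t)(X << C) >> C) == (X & (0xFFu >> C)));
          assert((uint8_t)((uint8_t)(X >> C) << C) == (X & (0xFFu << C)));
        }
    }
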
// If the 2nd shift is bigger than the 1st, we can fold:
@@ -98,7 +97,8 @@ static bool canEvaluateShiftedShift(unsigned OuterShAmt, bool IsOuterShl,
// Also, check that the inner shift is valid (less than the type width) or
// we'll crash trying to produce the bit mask for the 'and'.
unsigned TypeWidth = InnerShift->getType()->getScalarSizeInBits();
- if (InnerShAmt > OuterShAmt && InnerShAmt < TypeWidth) {
+ if (InnerShiftConst->ugt(OuterShAmt) && InnerShiftConst->ult(TypeWidth)) {
+ unsigned InnerShAmt = InnerShiftConst->getZExtValue();
unsigned MaskShift =
IsInnerShl ? TypeWidth - InnerShAmt : InnerShAmt - OuterShAmt;
APInt Mask = APInt::getLowBitsSet(TypeWidth, OuterShAmt) << MaskShift;
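
To make the mask computation concrete (an illustrative example, not part of the patch): with a type width of 8, an inner shl of 5 and an outer lshr of 3, MaskShift = 8 - 5 = 3 and Mask = getLowBitsSet(8, 3) << 3 = 0b00111000. Those are exactly the bits of X that lshr (shl X, 5), 3 discards but the cheaper shl X, 2 would keep, so the rewrite is only safe when MaskedValueIsZero proves them zero:

    #include <cassert>
    #include <cstdint>

    int main() {
      for (unsigned V = 0; V < 256; ++V) {
        uint8_t X = (uint8_t)V;
        if (X & 0x38)
          continue; // Mask bits are set; folding to a single shl would be wrong.
        assert((uint8_t)((uint8_t)(X << 5) >> 3) == (uint8_t)(X << 2));
      }
    }
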
@@ -135,7 +135,7 @@ static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
ConstantInt *CI = nullptr;
if ((IsLeftShift && match(I, m_LShr(m_Value(), m_ConstantInt(CI)))) ||
(!IsLeftShift && match(I, m_Shl(m_Value(), m_ConstantInt(CI))))) {
- if (CI->getZExtValue() == NumBits) {
+ if (CI->getValue() == NumBits) {
// TODO: Check that the input bits are already zero with MaskedValueIsZero
#if 0
// If this is a truncate of a logical shr, we can truncate it to a smaller
@@ -356,8 +356,10 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
// cast of lshr(shl(x,c1),c2) as well as other more complex cases.
if (I.getOpcode() != Instruction::AShr &&
canEvaluateShifted(Op0, Op1C->getZExtValue(), isLeftShift, *this, &I)) {
- DEBUG(dbgs() << "ICE: GetShiftedValue propagating shift through expression"
- " to eliminate shift:\n IN: " << *Op0 << "\n SH: " << I <<"\n");
+ LLVM_DEBUG(
+ dbgs() << "ICE: GetShiftedValue propagating shift through expression"
+ " to eliminate shift:\n IN: "
+ << *Op0 << "\n SH: " << I << "\n");
return replaceInstUsesWith(
I, getShiftedValue(Op0, Op1C->getZExtValue(), isLeftShift, *this, DL));
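
As an illustration of what getShiftedValue does here (a sketch, not part of the patch): shifts distribute over bitwise operators, so an outer shl can be pushed down through an 'or' until it meets an inner lshr by the same amount, where the pair collapses to a mask and no shift of x remains:

    // ((x >> 3) | y) << 3  ==  (x & 0xF8) | (y << 3)
    #include <cassert>
    #include <cstdint>

    int main() {
      for (unsigned x = 0; x < 256; ++x)
        for (unsigned y = 0; y < 256; ++y)
          assert((uint8_t)((((uint8_t)x >> 3) | y) << 3) ==
                 (uint8_t)(((uint8_t)x & 0xF8) | (y << 3)));
    }
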
@@ -370,7 +372,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
assert(!Op1C->uge(TypeBits) &&
"Shift over the type width should have been removed already");
- if (Instruction *FoldedShift = foldOpWithConstantIntoOperand(I))
+ if (Instruction *FoldedShift = foldBinOpIntoSelectOrPhi(I))
return FoldedShift;
// Fold shift2(trunc(shift1(x,c1)), c2) -> trunc(shift2(shift1(x,c1),c2))
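
The essential step in the shl case of this fold is that truncation commutes with the left shift, since both ends discard the same high bits. A standalone sketch (illustrative, not part of the patch), narrowing i32 to i8 with c2 = 3:

    #include <cassert>
    #include <cstdint>

    int main() {
      for (uint32_t y = 0; y < (1u << 16); ++y)
        assert((uint8_t)((uint8_t)y << 3) == (uint8_t)(y << 3));
    }
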
@@ -586,23 +588,23 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
}
Instruction *InstCombiner::visitShl(BinaryOperator &I) {
- if (Value *V = SimplifyVectorOp(I))
+ if (Value *V = SimplifyShlInst(I.getOperand(0), I.getOperand(1),
+ I.hasNoSignedWrap(), I.hasNoUnsignedWrap(),
+ SQ.getWithInstruction(&I)))
return replaceInstUsesWith(I, V);
- Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
- if (Value *V =
- SimplifyShlInst(Op0, Op1, I.hasNoSignedWrap(), I.hasNoUnsignedWrap(),
- SQ.getWithInstruction(&I)))
- return replaceInstUsesWith(I, V);
+ if (Instruction *X = foldShuffledBinop(I))
+ return X;
if (Instruction *V = commonShiftTransforms(I))
return V;
+ Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+ Type *Ty = I.getType();
const APInt *ShAmtAPInt;
if (match(Op1, m_APInt(ShAmtAPInt))) {
unsigned ShAmt = ShAmtAPInt->getZExtValue();
- unsigned BitWidth = I.getType()->getScalarSizeInBits();
- Type *Ty = I.getType();
+ unsigned BitWidth = Ty->getScalarSizeInBits();
// shl (zext X), ShAmt --> zext (shl X, ShAmt)
// This is only valid if X would have zeros shifted out.
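
The "zeros shifted out" condition can be checked exhaustively on a small case; an illustrative sketch (not part of the patch) with X : i8 zero-extended to i32 and ShAmt = 3:

    #include <cassert>
    #include <cstdint>

    int main() {
      for (unsigned V = 0; V < 256; ++V) {
        uint8_t X = (uint8_t)V;
        if (X >> 5)
          continue; // Top 3 bits set: the narrow shl would lose them.
        // shl (zext X), 3 == zext (shl X, 3) when nothing is shifted out.
        assert(((uint32_t)X << 3) == (uint32_t)(uint8_t)(X << 3));
      }
    }
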
@@ -620,11 +622,8 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) {
return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask));
}
- // Be careful about hiding shl instructions behind bit masks. They are used
- // to represent multiplies by a constant, and it is important that simple
- // arithmetic expressions are still recognizable by scalar evolution.
- // The inexact versions are deferred to DAGCombine, so we don't hide shl
- // behind a bit mask.
+ // FIXME: we do not yet transform non-exact shr's. The backend (DAGCombine)
+ // needs a few fixes for the rotate pattern recognition first.
const APInt *ShOp1;
if (match(Op0, m_Exact(m_Shr(m_Value(X), m_APInt(ShOp1))))) {
unsigned ShrAmt = ShOp1->getZExtValue();
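
The 'exact' flag is what makes this safe: it promises that the right shift dropped only zero bits, so the shift is reversible and the two shift amounts simply subtract. An illustrative sketch (not part of the patch) with an exact lshr by 3 on i8:

    #include <cassert>
    #include <cstdint>

    int main() {
      for (unsigned V = 0; V < 256; ++V) {
        uint8_t X = (uint8_t)V;
        if (X & 0x07)
          continue; // Emulate 'exact': the low 3 bits must be zero.
        assert((uint8_t)((X >> 3) << 5) == (uint8_t)(X << 2)); // shl amt > shr amt
        assert((uint8_t)((X >> 3) << 1) == (uint8_t)(X >> 2)); // shr amt > shl amt
      }
    }
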
@@ -668,6 +667,15 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) {
}
}
+ // Transform (x >> y) << y to x & (-1 << y)
+ // Valid for any type of right-shift.
+ Value *X;
+ if (match(Op0, m_OneUse(m_Shr(m_Value(X), m_Specific(Op1))))) {
+ Constant *AllOnes = ConstantInt::getAllOnesValue(Ty);
+ Value *Mask = Builder.CreateShl(AllOnes, Op1);
+ return BinaryOperator::CreateAnd(Mask, X);
+ }
+
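
A standalone sanity check of the new transform (illustrative, not part of the patch); it holds for lshr and ashr alike, which is why any right-shift is matched. Signed >> is assumed arithmetic (guaranteed since C++20):

    #include <cassert>
    #include <cstdint>

    int main() {
      for (unsigned y = 0; y < 8; ++y)
        for (unsigned v = 0; v < 256; ++v) {
          uint8_t x = (uint8_t)v;
          uint8_t mask = (uint8_t)(0xFFu << y);           // -1 << y on i8
          assert((uint8_t)((x >> y) << y) == (x & mask)); // lshr case
          assert((uint8_t)((unsigned)((int8_t)x >> y) << y) == (x & mask)); // ashr
        }
    }
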
Constant *C1;
if (match(Op1, m_Constant(C1))) {
Constant *C2;
@@ -685,17 +693,17 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) {
}
Instruction *InstCombiner::visitLShr(BinaryOperator &I) {
- if (Value *V = SimplifyVectorOp(I))
+ if (Value *V = SimplifyLShrInst(I.getOperand(0), I.getOperand(1), I.isExact(),
+ SQ.getWithInstruction(&I)))
return replaceInstUsesWith(I, V);
- Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
- if (Value *V =
- SimplifyLShrInst(Op0, Op1, I.isExact(), SQ.getWithInstruction(&I)))
- return replaceInstUsesWith(I, V);
+ if (Instruction *X = foldShuffledBinop(I))
+ return X;
if (Instruction *R = commonShiftTransforms(I))
return R;
+ Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
Type *Ty = I.getType();
const APInt *ShAmtAPInt;
if (match(Op1, m_APInt(ShAmtAPInt))) {
@@ -800,25 +808,34 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) {
return &I;
}
}
+
+ // Transform (x << y) >> y to x & (-1 >> y)
+ Value *X;
+ if (match(Op0, m_OneUse(m_Shl(m_Value(X), m_Specific(Op1))))) {
+ Constant *AllOnes = ConstantInt::getAllOnesValue(Ty);
+ Value *Mask = Builder.CreateLShr(AllOnes, Op1);
+ return BinaryOperator::CreateAnd(Mask, X);
+ }
+
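
The same style of check for the lshr dual (illustrative, not part of the patch):

    #include <cassert>
    #include <cstdint>

    int main() {
      for (unsigned y = 0; y < 8; ++y)
        for (unsigned v = 0; v < 256; ++v) {
          uint8_t x = (uint8_t)v;
          // (x << y) >> y == x & (-1 >> y): the low bits survive the round trip.
          assert((uint8_t)((uint8_t)(x << y) >> y) == (x & (0xFFu >> y)));
        }
    }
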
return nullptr;
}
Instruction *InstCombiner::visitAShr(BinaryOperator &I) {
- if (Value *V = SimplifyVectorOp(I))
+ if (Value *V = SimplifyAShrInst(I.getOperand(0), I.getOperand(1), I.isExact(),
+ SQ.getWithInstruction(&I)))
return replaceInstUsesWith(I, V);
- Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
- if (Value *V =
- SimplifyAShrInst(Op0, Op1, I.isExact(), SQ.getWithInstruction(&I)))
- return replaceInstUsesWith(I, V);
+ if (Instruction *X = foldShuffledBinop(I))
+ return X;
if (Instruction *R = commonShiftTransforms(I))
return R;
+ Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
Type *Ty = I.getType();
unsigned BitWidth = Ty->getScalarSizeInBits();
const APInt *ShAmtAPInt;
- if (match(Op1, m_APInt(ShAmtAPInt))) {
+ if (match(Op1, m_APInt(ShAmtAPInt)) && ShAmtAPInt->ult(BitWidth)) {
unsigned ShAmt = ShAmtAPInt->getZExtValue();
// If the shift amount equals the difference in width of the destination
@@ -832,7 +849,8 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) {
// We can't handle (X << C1) >>s C2. It shifts arbitrary bits in. However,
// we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits.
const APInt *ShOp1;
- if (match(Op0, m_NSWShl(m_Value(X), m_APInt(ShOp1)))) {
+ if (match(Op0, m_NSWShl(m_Value(X), m_APInt(ShOp1))) &&
+ ShOp1->ult(BitWidth)) {
unsigned ShlAmt = ShOp1->getZExtValue();
if (ShlAmt < ShAmt) {
// (X <<nsw C1) >>s C2 --> X >>s (C2 - C1)
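
Why nsw is required: a plain shl can push arbitrary value bits into the sign position, but an nsw shl only ever shifts out copies of the sign bit, so a wider ashr can absorb it by shifting the difference. An illustrative sketch (not part of the patch), i8 with C1 = 2, C2 = 5, assuming C++20 signed-shift semantics:

    #include <cassert>
    #include <cstdint>

    int main() {
      for (int v = -128; v < 128; ++v) {
        int8_t X = (int8_t)v;
        int8_t Shl = (int8_t)((uint8_t)X << 2);
        if ((int8_t)(Shl >> 2) != X)
          continue; // This X overflows: shl is not nsw, fold not applicable.
        assert((int8_t)(Shl >> 5) == (int8_t)(X >> 3)); // (X <<nsw 2) >>s 5 == X >>s 3
      }
    }
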
@@ -850,7 +868,8 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) {
}
}
- if (match(Op0, m_AShr(m_Value(X), m_APInt(ShOp1)))) {
+ if (match(Op0, m_AShr(m_Value(X), m_APInt(ShOp1))) &&
+ ShOp1->ult(BitWidth)) {
unsigned AmtSum = ShAmt + ShOp1->getZExtValue();
// Oversized arithmetic shifts replicate the sign bit.
AmtSum = std::min(AmtSum, BitWidth - 1);
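
Clamping, rather than rejecting, is correct because an arithmetic shift of BitWidth - 1 or more already yields pure sign bits. An illustrative check (not part of the patch) on i8, assuming C++20 signed-shift semantics:

    #include <cassert>
    #include <cstdint>

    int main() {
      for (int v = -128; v < 128; ++v) {
        int8_t X = (int8_t)v;
        // AmtSum = 6 + 6 = 12 would be oversized; BitWidth - 1 = 7 acts the same.
        assert((int8_t)((int8_t)(X >> 6) >> 6) == (int8_t)(X >> 7));
      }
    }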