diff options
Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineShifts.cpp')
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineShifts.cpp | 85 |
1 files changed, 52 insertions, 33 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineShifts.cpp b/lib/Transforms/InstCombine/InstCombineShifts.cpp index 44bbb84686ab..34f8037e519f 100644 --- a/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -87,8 +87,7 @@ static bool canEvaluateShiftedShift(unsigned OuterShAmt, bool IsOuterShl, // Equal shift amounts in opposite directions become bitwise 'and': // lshr (shl X, C), C --> and X, C' // shl (lshr X, C), C --> and X, C' - unsigned InnerShAmt = InnerShiftConst->getZExtValue(); - if (InnerShAmt == OuterShAmt) + if (*InnerShiftConst == OuterShAmt) return true; // If the 2nd shift is bigger than the 1st, we can fold: @@ -98,7 +97,8 @@ static bool canEvaluateShiftedShift(unsigned OuterShAmt, bool IsOuterShl, // Also, check that the inner shift is valid (less than the type width) or // we'll crash trying to produce the bit mask for the 'and'. unsigned TypeWidth = InnerShift->getType()->getScalarSizeInBits(); - if (InnerShAmt > OuterShAmt && InnerShAmt < TypeWidth) { + if (InnerShiftConst->ugt(OuterShAmt) && InnerShiftConst->ult(TypeWidth)) { + unsigned InnerShAmt = InnerShiftConst->getZExtValue(); unsigned MaskShift = IsInnerShl ? TypeWidth - InnerShAmt : InnerShAmt - OuterShAmt; APInt Mask = APInt::getLowBitsSet(TypeWidth, OuterShAmt) << MaskShift; @@ -135,7 +135,7 @@ static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift, ConstantInt *CI = nullptr; if ((IsLeftShift && match(I, m_LShr(m_Value(), m_ConstantInt(CI)))) || (!IsLeftShift && match(I, m_Shl(m_Value(), m_ConstantInt(CI))))) { - if (CI->getZExtValue() == NumBits) { + if (CI->getValue() == NumBits) { // TODO: Check that the input bits are already zero with MaskedValueIsZero #if 0 // If this is a truncate of a logical shr, we can truncate it to a smaller @@ -356,8 +356,10 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, // cast of lshr(shl(x,c1),c2) as well as other more complex cases. if (I.getOpcode() != Instruction::AShr && canEvaluateShifted(Op0, Op1C->getZExtValue(), isLeftShift, *this, &I)) { - DEBUG(dbgs() << "ICE: GetShiftedValue propagating shift through expression" - " to eliminate shift:\n IN: " << *Op0 << "\n SH: " << I <<"\n"); + LLVM_DEBUG( + dbgs() << "ICE: GetShiftedValue propagating shift through expression" + " to eliminate shift:\n IN: " + << *Op0 << "\n SH: " << I << "\n"); return replaceInstUsesWith( I, getShiftedValue(Op0, Op1C->getZExtValue(), isLeftShift, *this, DL)); @@ -370,7 +372,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, assert(!Op1C->uge(TypeBits) && "Shift over the type width should have been removed already"); - if (Instruction *FoldedShift = foldOpWithConstantIntoOperand(I)) + if (Instruction *FoldedShift = foldBinOpIntoSelectOrPhi(I)) return FoldedShift; // Fold shift2(trunc(shift1(x,c1)), c2) -> trunc(shift2(shift1(x,c1),c2)) @@ -586,23 +588,23 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, } Instruction *InstCombiner::visitShl(BinaryOperator &I) { - if (Value *V = SimplifyVectorOp(I)) + if (Value *V = SimplifyShlInst(I.getOperand(0), I.getOperand(1), + I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - if (Value *V = - SimplifyShlInst(Op0, Op1, I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), - SQ.getWithInstruction(&I))) - return replaceInstUsesWith(I, V); + if (Instruction *X = foldShuffledBinop(I)) + return X; if (Instruction *V = commonShiftTransforms(I)) return V; + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Type *Ty = I.getType(); const APInt *ShAmtAPInt; if (match(Op1, m_APInt(ShAmtAPInt))) { unsigned ShAmt = ShAmtAPInt->getZExtValue(); - unsigned BitWidth = I.getType()->getScalarSizeInBits(); - Type *Ty = I.getType(); + unsigned BitWidth = Ty->getScalarSizeInBits(); // shl (zext X), ShAmt --> zext (shl X, ShAmt) // This is only valid if X would have zeros shifted out. @@ -620,11 +622,8 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask)); } - // Be careful about hiding shl instructions behind bit masks. They are used - // to represent multiplies by a constant, and it is important that simple - // arithmetic expressions are still recognizable by scalar evolution. - // The inexact versions are deferred to DAGCombine, so we don't hide shl - // behind a bit mask. + // FIXME: we do not yet transform non-exact shr's. The backend (DAGCombine) + // needs a few fixes for the rotate pattern recognition first. const APInt *ShOp1; if (match(Op0, m_Exact(m_Shr(m_Value(X), m_APInt(ShOp1))))) { unsigned ShrAmt = ShOp1->getZExtValue(); @@ -668,6 +667,15 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { } } + // Transform (x >> y) << y to x & (-1 << y) + // Valid for any type of right-shift. + Value *X; + if (match(Op0, m_OneUse(m_Shr(m_Value(X), m_Specific(Op1))))) { + Constant *AllOnes = ConstantInt::getAllOnesValue(Ty); + Value *Mask = Builder.CreateShl(AllOnes, Op1); + return BinaryOperator::CreateAnd(Mask, X); + } + Constant *C1; if (match(Op1, m_Constant(C1))) { Constant *C2; @@ -685,17 +693,17 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { } Instruction *InstCombiner::visitLShr(BinaryOperator &I) { - if (Value *V = SimplifyVectorOp(I)) + if (Value *V = SimplifyLShrInst(I.getOperand(0), I.getOperand(1), I.isExact(), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - if (Value *V = - SimplifyLShrInst(Op0, Op1, I.isExact(), SQ.getWithInstruction(&I))) - return replaceInstUsesWith(I, V); + if (Instruction *X = foldShuffledBinop(I)) + return X; if (Instruction *R = commonShiftTransforms(I)) return R; + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); Type *Ty = I.getType(); const APInt *ShAmtAPInt; if (match(Op1, m_APInt(ShAmtAPInt))) { @@ -800,25 +808,34 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) { return &I; } } + + // Transform (x << y) >> y to x & (-1 >> y) + Value *X; + if (match(Op0, m_OneUse(m_Shl(m_Value(X), m_Specific(Op1))))) { + Constant *AllOnes = ConstantInt::getAllOnesValue(Ty); + Value *Mask = Builder.CreateLShr(AllOnes, Op1); + return BinaryOperator::CreateAnd(Mask, X); + } + return nullptr; } Instruction *InstCombiner::visitAShr(BinaryOperator &I) { - if (Value *V = SimplifyVectorOp(I)) + if (Value *V = SimplifyAShrInst(I.getOperand(0), I.getOperand(1), I.isExact(), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - if (Value *V = - SimplifyAShrInst(Op0, Op1, I.isExact(), SQ.getWithInstruction(&I))) - return replaceInstUsesWith(I, V); + if (Instruction *X = foldShuffledBinop(I)) + return X; if (Instruction *R = commonShiftTransforms(I)) return R; + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); Type *Ty = I.getType(); unsigned BitWidth = Ty->getScalarSizeInBits(); const APInt *ShAmtAPInt; - if (match(Op1, m_APInt(ShAmtAPInt))) { + if (match(Op1, m_APInt(ShAmtAPInt)) && ShAmtAPInt->ult(BitWidth)) { unsigned ShAmt = ShAmtAPInt->getZExtValue(); // If the shift amount equals the difference in width of the destination @@ -832,7 +849,8 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) { // We can't handle (X << C1) >>s C2. It shifts arbitrary bits in. However, // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits. const APInt *ShOp1; - if (match(Op0, m_NSWShl(m_Value(X), m_APInt(ShOp1)))) { + if (match(Op0, m_NSWShl(m_Value(X), m_APInt(ShOp1))) && + ShOp1->ult(BitWidth)) { unsigned ShlAmt = ShOp1->getZExtValue(); if (ShlAmt < ShAmt) { // (X <<nsw C1) >>s C2 --> X >>s (C2 - C1) @@ -850,7 +868,8 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) { } } - if (match(Op0, m_AShr(m_Value(X), m_APInt(ShOp1)))) { + if (match(Op0, m_AShr(m_Value(X), m_APInt(ShOp1))) && + ShOp1->ult(BitWidth)) { unsigned AmtSum = ShAmt + ShOp1->getZExtValue(); // Oversized arithmetic shifts replicate the sign bit. AmtSum = std::min(AmtSum, BitWidth - 1); |