diff options
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp')
| -rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp | 163 |
1 file changed, 108 insertions, 55 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp index 89dad455f015..b7958978c450 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -136,9 +136,14 @@ Value *InstCombinerImpl::reassociateShiftAmtsOfTwoSameDirectionShifts( assert(IdenticalShOpcodes && "Should not get here with different shifts."); - // All good, we can do this fold. - NewShAmt = ConstantExpr::getZExtOrBitCast(NewShAmt, X->getType()); + if (NewShAmt->getType() != X->getType()) { + NewShAmt = ConstantFoldCastOperand(Instruction::ZExt, NewShAmt, + X->getType(), SQ.DL); + if (!NewShAmt) + return nullptr; + } + // All good, we can do this fold. BinaryOperator *NewShift = BinaryOperator::Create(ShiftOpcode, X, NewShAmt); // The flags can only be propagated if there wasn't a trunc. @@ -245,7 +250,11 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift, SumOfShAmts = Constant::replaceUndefsWith( SumOfShAmts, ConstantInt::get(SumOfShAmts->getType()->getScalarType(), ExtendedTy->getScalarSizeInBits())); - auto *ExtendedSumOfShAmts = ConstantExpr::getZExt(SumOfShAmts, ExtendedTy); + auto *ExtendedSumOfShAmts = ConstantFoldCastOperand( + Instruction::ZExt, SumOfShAmts, ExtendedTy, Q.DL); + if (!ExtendedSumOfShAmts) + return nullptr; + // And compute the mask as usual: ~(-1 << (SumOfShAmts)) auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy); auto *ExtendedInvertedMask = @@ -278,16 +287,22 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift, ShAmtsDiff = Constant::replaceUndefsWith( ShAmtsDiff, ConstantInt::get(ShAmtsDiff->getType()->getScalarType(), -WidestTyBitWidth)); - auto *ExtendedNumHighBitsToClear = ConstantExpr::getZExt( + auto *ExtendedNumHighBitsToClear = ConstantFoldCastOperand( + Instruction::ZExt, ConstantExpr::getSub(ConstantInt::get(ShAmtsDiff->getType(), WidestTyBitWidth, /*isSigned=*/false), 
ShAmtsDiff), - ExtendedTy); + ExtendedTy, Q.DL); + if (!ExtendedNumHighBitsToClear) + return nullptr; + // And compute the mask as usual: (-1 l>> (NumHighBitsToClear)) auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy); - NewMask = - ConstantExpr::getLShr(ExtendedAllOnes, ExtendedNumHighBitsToClear); + NewMask = ConstantFoldBinaryOpOperands(Instruction::LShr, ExtendedAllOnes, + ExtendedNumHighBitsToClear, Q.DL); + if (!NewMask) + return nullptr; } else return nullptr; // Don't know anything about this pattern. @@ -545,8 +560,8 @@ static bool canEvaluateShiftedShift(unsigned OuterShAmt, bool IsOuterShl, /// this succeeds, getShiftedValue() will be called to produce the value. static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift, InstCombinerImpl &IC, Instruction *CxtI) { - // We can always evaluate constants shifted. - if (isa<Constant>(V)) + // We can always evaluate immediate constants. + if (match(V, m_ImmConstant())) return true; Instruction *I = dyn_cast<Instruction>(V); @@ -709,13 +724,13 @@ static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift, case Instruction::Mul: { assert(!isLeftShift && "Unexpected shift direction!"); auto *Neg = BinaryOperator::CreateNeg(I->getOperand(0)); - IC.InsertNewInstWith(Neg, *I); + IC.InsertNewInstWith(Neg, I->getIterator()); unsigned TypeWidth = I->getType()->getScalarSizeInBits(); APInt Mask = APInt::getLowBitsSet(TypeWidth, TypeWidth - NumBits); auto *And = BinaryOperator::CreateAnd(Neg, ConstantInt::get(I->getType(), Mask)); And->takeName(I); - return IC.InsertNewInstWith(And, *I); + return IC.InsertNewInstWith(And, I->getIterator()); } } } @@ -745,7 +760,7 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *C1, // (C2 >> X) >> C1 --> (C2 >> C1) >> X Constant *C2; Value *X; - if (match(Op0, m_BinOp(I.getOpcode(), m_Constant(C2), m_Value(X)))) + if (match(Op0, m_BinOp(I.getOpcode(), m_ImmConstant(C2), m_Value(X)))) return 
BinaryOperator::Create( I.getOpcode(), Builder.CreateBinOp(I.getOpcode(), C2, C1), X); @@ -928,6 +943,60 @@ Instruction *InstCombinerImpl::foldLShrOverflowBit(BinaryOperator &I) { return new ZExtInst(Overflow, Ty); } +// Try to set nuw/nsw flags on shl or exact flag on lshr/ashr using knownbits. +static bool setShiftFlags(BinaryOperator &I, const SimplifyQuery &Q) { + assert(I.isShift() && "Expected a shift as input"); + // We already have all the flags. + if (I.getOpcode() == Instruction::Shl) { + if (I.hasNoUnsignedWrap() && I.hasNoSignedWrap()) + return false; + } else { + if (I.isExact()) + return false; + + // shr (shl X, Y), Y + if (match(I.getOperand(0), m_Shl(m_Value(), m_Specific(I.getOperand(1))))) { + I.setIsExact(); + return true; + } + } + + // Compute what we know about shift count. + KnownBits KnownCnt = computeKnownBits(I.getOperand(1), /* Depth */ 0, Q); + unsigned BitWidth = KnownCnt.getBitWidth(); + // Since shift produces a poison value if RHS is equal to or larger than the + // bit width, we can safely assume that RHS is less than the bit width. + uint64_t MaxCnt = KnownCnt.getMaxValue().getLimitedValue(BitWidth - 1); + + KnownBits KnownAmt = computeKnownBits(I.getOperand(0), /* Depth */ 0, Q); + bool Changed = false; + + if (I.getOpcode() == Instruction::Shl) { + // If we have as many leading zeros than maximum shift cnt we have nuw. + if (!I.hasNoUnsignedWrap() && MaxCnt <= KnownAmt.countMinLeadingZeros()) { + I.setHasNoUnsignedWrap(); + Changed = true; + } + // If we have more sign bits than maximum shift cnt we have nsw. + if (!I.hasNoSignedWrap()) { + if (MaxCnt < KnownAmt.countMinSignBits() || + MaxCnt < ComputeNumSignBits(I.getOperand(0), Q.DL, /*Depth*/ 0, Q.AC, + Q.CxtI, Q.DT)) { + I.setHasNoSignedWrap(); + Changed = true; + } + } + return Changed; + } + + // If we have at least as many trailing zeros as maximum count then we have + // exact. 
+ Changed = MaxCnt <= KnownAmt.countMinTrailingZeros(); + I.setIsExact(Changed); + + return Changed; +} + Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) { const SimplifyQuery Q = SQ.getWithInstruction(&I); @@ -976,7 +1045,11 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) { // If C1 < C: (X >>?,exact C1) << C --> X << (C - C1) Constant *ShiftDiff = ConstantInt::get(Ty, ShAmtC - ShrAmt); auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff); - NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); + NewShl->setHasNoUnsignedWrap( + I.hasNoUnsignedWrap() || + (ShrAmt && + cast<Instruction>(Op0)->getOpcode() == Instruction::LShr && + I.hasNoSignedWrap())); NewShl->setHasNoSignedWrap(I.hasNoSignedWrap()); return NewShl; } @@ -997,7 +1070,11 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) { // If C1 < C: (X >>? C1) << C --> (X << (C - C1)) & (-1 << C) Constant *ShiftDiff = ConstantInt::get(Ty, ShAmtC - ShrAmt); auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff); - NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); + NewShl->setHasNoUnsignedWrap( + I.hasNoUnsignedWrap() || + (ShrAmt && + cast<Instruction>(Op0)->getOpcode() == Instruction::LShr && + I.hasNoSignedWrap())); NewShl->setHasNoSignedWrap(I.hasNoSignedWrap()); Builder.Insert(NewShl); APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmtC)); @@ -1108,22 +1185,11 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) { Value *NewShift = Builder.CreateShl(X, Op1); return BinaryOperator::CreateSub(NewLHS, NewShift); } - - // If the shifted-out value is known-zero, then this is a NUW shift. - if (!I.hasNoUnsignedWrap() && - MaskedValueIsZero(Op0, APInt::getHighBitsSet(BitWidth, ShAmtC), 0, - &I)) { - I.setHasNoUnsignedWrap(); - return &I; - } - - // If the shifted-out value is all signbits, then this is a NSW shift. 
- if (!I.hasNoSignedWrap() && ComputeNumSignBits(Op0, 0, &I) > ShAmtC) { - I.setHasNoSignedWrap(); - return &I; - } } + if (setShiftFlags(I, Q)) + return &I; + // Transform (x >> y) << y to x & (-1 << y) // Valid for any type of right-shift. Value *X; @@ -1161,15 +1227,6 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) { Value *NegX = Builder.CreateNeg(X, "neg"); return BinaryOperator::CreateAnd(NegX, X); } - - // The only way to shift out the 1 is with an over-shift, so that would - // be poison with or without "nuw". Undef is excluded because (undef << X) - // is not undef (it is zero). - Constant *ConstantOne = cast<Constant>(Op0); - if (!I.hasNoUnsignedWrap() && !ConstantOne->containsUndefElement()) { - I.setHasNoUnsignedWrap(); - return &I; - } } return nullptr; @@ -1235,9 +1292,10 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { unsigned ShlAmtC = C1->getZExtValue(); Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmtC - ShAmtC); if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) { - // (X <<nuw C1) >>u C --> X <<nuw (C1 - C) + // (X <<nuw C1) >>u C --> X <<nuw/nsw (C1 - C) auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff); NewShl->setHasNoUnsignedWrap(true); + NewShl->setHasNoSignedWrap(ShAmtC > 0); return NewShl; } if (Op0->hasOneUse()) { @@ -1370,12 +1428,13 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { if (Op0->hasOneUse()) { APInt NewMulC = MulC->lshr(ShAmtC); // if c is divisible by (1 << ShAmtC): - // lshr (mul nuw x, MulC), ShAmtC -> mul nuw x, (MulC >> ShAmtC) + // lshr (mul nuw x, MulC), ShAmtC -> mul nuw nsw x, (MulC >> ShAmtC) if (MulC->eq(NewMulC.shl(ShAmtC))) { auto *NewMul = BinaryOperator::CreateNUWMul(X, ConstantInt::get(Ty, NewMulC)); - BinaryOperator *OrigMul = cast<BinaryOperator>(Op0); - NewMul->setHasNoSignedWrap(OrigMul->hasNoSignedWrap()); + assert(ShAmtC != 0 && + "lshr X, 0 should be handled by simplifyLShrInst."); + NewMul->setHasNoSignedWrap(true); return NewMul; } } @@ -1414,15 
+1473,12 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { Value *And = Builder.CreateAnd(BoolX, BoolY); return new ZExtInst(And, Ty); } - - // If the shifted-out value is known-zero, then this is an exact shift. - if (!I.isExact() && - MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmtC), 0, &I)) { - I.setIsExact(); - return &I; - } } + const SimplifyQuery Q = SQ.getWithInstruction(&I); + if (setShiftFlags(I, Q)) + return &I; + // Transform (x << y) >> y to x & (-1 >> y) if (match(Op0, m_OneUse(m_Shl(m_Value(X), m_Specific(Op1))))) { Constant *AllOnes = ConstantInt::getAllOnesValue(Ty); @@ -1581,15 +1637,12 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) { if (match(Op0, m_OneUse(m_NSWSub(m_Value(X), m_Value(Y))))) return new SExtInst(Builder.CreateICmpSLT(X, Y), Ty); } - - // If the shifted-out value is known-zero, then this is an exact shift. - if (!I.isExact() && - MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) { - I.setIsExact(); - return &I; - } } + const SimplifyQuery Q = SQ.getWithInstruction(&I); + if (setShiftFlags(I, Q)) + return &I; + // Prefer `-(x & 1)` over `(x << (bitwidth(x)-1)) a>> (bitwidth(x)-1)` // as the pattern to splat the lowest bit. // FIXME: iff X is already masked, we don't need the one-use check. |
