diff options
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp')
| -rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp | 157 |
1 files changed, 136 insertions, 21 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp index 13c98b935adf..ec505381cc86 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -346,8 +346,8 @@ static Instruction *foldShiftOfShiftedLogic(BinaryOperator &I, Value *X, *Y; auto matchFirstShift = [&](Value *V) { APInt Threshold(Ty->getScalarSizeInBits(), Ty->getScalarSizeInBits()); - return match(V, m_BinOp(ShiftOpcode, m_Value(), m_Value())) && - match(V, m_OneUse(m_Shift(m_Value(X), m_Constant(C0)))) && + return match(V, + m_OneUse(m_BinOp(ShiftOpcode, m_Value(X), m_Constant(C0)))) && match(ConstantExpr::getAdd(C0, C1), m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, Threshold)); }; @@ -363,7 +363,7 @@ static Instruction *foldShiftOfShiftedLogic(BinaryOperator &I, // shift (logic (shift X, C0), Y), C1 -> logic (shift X, C0+C1), (shift Y, C1) Constant *ShiftSumC = ConstantExpr::getAdd(C0, C1); Value *NewShift1 = Builder.CreateBinOp(ShiftOpcode, X, ShiftSumC); - Value *NewShift2 = Builder.CreateBinOp(ShiftOpcode, Y, I.getOperand(1)); + Value *NewShift2 = Builder.CreateBinOp(ShiftOpcode, Y, C1); return BinaryOperator::Create(LogicInst->getOpcode(), NewShift1, NewShift2); } @@ -730,13 +730,34 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *C1, return BinaryOperator::Create( I.getOpcode(), Builder.CreateBinOp(I.getOpcode(), C2, C1), X); + bool IsLeftShift = I.getOpcode() == Instruction::Shl; + Type *Ty = I.getType(); + unsigned TypeBits = Ty->getScalarSizeInBits(); + + // (X / +DivC) >> (Width - 1) --> ext (X <= -DivC) + // (X / -DivC) >> (Width - 1) --> ext (X >= +DivC) + const APInt *DivC; + if (!IsLeftShift && match(C1, m_SpecificIntAllowUndef(TypeBits - 1)) && + match(Op0, m_SDiv(m_Value(X), m_APInt(DivC))) && !DivC->isZero() && + !DivC->isMinSignedValue()) { + Constant *NegDivC = ConstantInt::get(Ty, -(*DivC)); + ICmpInst::Predicate Pred = + 
DivC->isNegative() ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_SLE; + Value *Cmp = Builder.CreateICmp(Pred, X, NegDivC); + auto ExtOpcode = (I.getOpcode() == Instruction::AShr) ? Instruction::SExt + : Instruction::ZExt; + return CastInst::Create(ExtOpcode, Cmp, Ty); + } + const APInt *Op1C; if (!match(C1, m_APInt(Op1C))) return nullptr; + assert(!Op1C->uge(TypeBits) && + "Shift over the type width should have been removed already"); + // See if we can propagate this shift into the input, this covers the trivial // cast of lshr(shl(x,c1),c2) as well as other more complex cases. - bool IsLeftShift = I.getOpcode() == Instruction::Shl; if (I.getOpcode() != Instruction::AShr && canEvaluateShifted(Op0, Op1C->getZExtValue(), IsLeftShift, *this, &I)) { LLVM_DEBUG( @@ -748,14 +769,6 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *C1, I, getShiftedValue(Op0, Op1C->getZExtValue(), IsLeftShift, *this, DL)); } - // See if we can simplify any instructions used by the instruction whose sole - // purpose is to compute bits we don't care about. - Type *Ty = I.getType(); - unsigned TypeBits = Ty->getScalarSizeInBits(); - assert(!Op1C->uge(TypeBits) && - "Shift over the type width should have been removed already"); - (void)TypeBits; - if (Instruction *FoldedShift = foldBinOpIntoSelectOrPhi(I)) return FoldedShift; @@ -826,6 +839,74 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *C1, return nullptr; } +// Tries to perform +// (lshr (add (zext X), (zext Y)), K) +// -> (icmp ult (add X, Y), X) +// where +// - The add's operands are zexts from a K-bits integer to a bigger type. +// - The add is only used by the shr, or by iK (or narrower) truncates. +// - The lshr type has more than 2 bits (other types are boolean math). +// - K > 1 +// note that +// - The resulting add cannot have nuw/nsw, else on overflow we get a +// poison value and the transform isn't legal anymore. 
+Instruction *InstCombinerImpl::foldLShrOverflowBit(BinaryOperator &I) { + assert(I.getOpcode() == Instruction::LShr); + + Value *Add = I.getOperand(0); + Value *ShiftAmt = I.getOperand(1); + Type *Ty = I.getType(); + + if (Ty->getScalarSizeInBits() < 3) + return nullptr; + + const APInt *ShAmtAPInt = nullptr; + Value *X = nullptr, *Y = nullptr; + if (!match(ShiftAmt, m_APInt(ShAmtAPInt)) || + !match(Add, + m_Add(m_OneUse(m_ZExt(m_Value(X))), m_OneUse(m_ZExt(m_Value(Y)))))) + return nullptr; + + const unsigned ShAmt = ShAmtAPInt->getZExtValue(); + if (ShAmt == 1) + return nullptr; + + // X/Y are zexts from `ShAmt`-sized ints. + if (X->getType()->getScalarSizeInBits() != ShAmt || + Y->getType()->getScalarSizeInBits() != ShAmt) + return nullptr; + + // Make sure that `Add` is only used by `I` and `ShAmt`-truncates. + if (!Add->hasOneUse()) { + for (User *U : Add->users()) { + if (U == &I) + continue; + + TruncInst *Trunc = dyn_cast<TruncInst>(U); + if (!Trunc || Trunc->getType()->getScalarSizeInBits() > ShAmt) + return nullptr; + } + } + + // Insert at Add so that the newly created `NarrowAdd` will dominate its + // users (i.e. `Add`'s users). + Instruction *AddInst = cast<Instruction>(Add); + Builder.SetInsertPoint(AddInst); + + Value *NarrowAdd = Builder.CreateAdd(X, Y, "add.narrowed"); + Value *Overflow = + Builder.CreateICmpULT(NarrowAdd, X, "add.narrowed.overflow"); + + // Replace the uses of the original add with a zext of the + // NarrowAdd's result. Note that all users at this stage are known to + // be ShAmt-sized truncs, or the lshr itself. + if (!Add->hasOneUse()) + replaceInstUsesWith(*AddInst, Builder.CreateZExt(NarrowAdd, Ty)); + + // Replace the LShr with a zext of the overflow check. 
+ return new ZExtInst(Overflow, Ty); +} + Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) { const SimplifyQuery Q = SQ.getWithInstruction(&I); @@ -1046,11 +1127,21 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) { } } - // (1 << (C - x)) -> ((1 << C) >> x) if C is bitwidth - 1 - if (match(Op0, m_One()) && - match(Op1, m_Sub(m_SpecificInt(BitWidth - 1), m_Value(X)))) - return BinaryOperator::CreateLShr( - ConstantInt::get(Ty, APInt::getSignMask(BitWidth)), X); + if (match(Op0, m_One())) { + // (1 << (C - x)) -> ((1 << C) >> x) if C is bitwidth - 1 + if (match(Op1, m_Sub(m_SpecificInt(BitWidth - 1), m_Value(X)))) + return BinaryOperator::CreateLShr( + ConstantInt::get(Ty, APInt::getSignMask(BitWidth)), X); + + // The only way to shift out the 1 is with an over-shift, so that would + // be poison with or without "nuw". Undef is excluded because (undef << X) + // is not undef (it is zero). + Constant *ConstantOne = cast<Constant>(Op0); + if (!I.hasNoUnsignedWrap() && !ConstantOne->containsUndefElement()) { + I.setHasNoUnsignedWrap(); + return &I; + } + } return nullptr; } @@ -1068,10 +1159,17 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); Type *Ty = I.getType(); + Value *X; const APInt *C; + unsigned BitWidth = Ty->getScalarSizeInBits(); + + // (iN (~X) u>> (N - 1)) --> zext (X > -1) + if (match(Op0, m_OneUse(m_Not(m_Value(X)))) && + match(Op1, m_SpecificIntAllowUndef(BitWidth - 1))) + return new ZExtInst(Builder.CreateIsNotNeg(X, "isnotneg"), Ty); + if (match(Op1, m_APInt(C))) { unsigned ShAmtC = C->getZExtValue(); - unsigned BitWidth = Ty->getScalarSizeInBits(); auto *II = dyn_cast<IntrinsicInst>(Op0); if (II && isPowerOf2_32(BitWidth) && Log2_32(BitWidth) == ShAmtC && (II->getIntrinsicID() == Intrinsic::ctlz || @@ -1276,6 +1374,18 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { } } + // Reduce add-carry of bools to logic: + // ((zext BoolX) + (zext 
BoolY)) >> 1 --> zext (BoolX && BoolY) + Value *BoolX, *BoolY; + if (ShAmtC == 1 && match(Op0, m_Add(m_Value(X), m_Value(Y))) && + match(X, m_ZExt(m_Value(BoolX))) && match(Y, m_ZExt(m_Value(BoolY))) && + BoolX->getType()->isIntOrIntVectorTy(1) && + BoolY->getType()->isIntOrIntVectorTy(1) && + (X->hasOneUse() || Y->hasOneUse() || Op0->hasOneUse())) { + Value *And = Builder.CreateAnd(BoolX, BoolY); + return new ZExtInst(And, Ty); + } + // If the shifted-out value is known-zero, then this is an exact shift. if (!I.isExact() && MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmtC), 0, &I)) { @@ -1285,13 +1395,15 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { } // Transform (x << y) >> y to x & (-1 >> y) - Value *X; if (match(Op0, m_OneUse(m_Shl(m_Value(X), m_Specific(Op1))))) { Constant *AllOnes = ConstantInt::getAllOnesValue(Ty); Value *Mask = Builder.CreateLShr(AllOnes, Op1); return BinaryOperator::CreateAnd(Mask, X); } + if (Instruction *Overflow = foldLShrOverflowBit(I)) + return Overflow; + return nullptr; } @@ -1469,8 +1581,11 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) { return R; // See if we can turn a signed shr into an unsigned shr. - if (MaskedValueIsZero(Op0, APInt::getSignMask(BitWidth), 0, &I)) - return BinaryOperator::CreateLShr(Op0, Op1); + if (MaskedValueIsZero(Op0, APInt::getSignMask(BitWidth), 0, &I)) { + Instruction *Lshr = BinaryOperator::CreateLShr(Op0, Op1); + Lshr->setIsExact(I.isExact()); + return Lshr; + } // ashr (xor %x, -1), %y --> xor (ashr %x, %y), -1 if (match(Op0, m_OneUse(m_Not(m_Value(X))))) { |
