diff options
| author | Dimitry Andric <dim@FreeBSD.org> | 2017-01-02 19:17:04 +0000 |
|---|---|---|
| committer | Dimitry Andric <dim@FreeBSD.org> | 2017-01-02 19:17:04 +0000 |
| commit | b915e9e0fc85ba6f398b3fab0db6a81a8913af94 (patch) | |
| tree | 98b8f811c7aff2547cab8642daf372d6c59502fb /lib/Transforms/InstCombine | |
| parent | 6421cca32f69ac849537a3cff78c352195e99f1b (diff) | |
Notes
Diffstat (limited to 'lib/Transforms/InstCombine')
| -rw-r--r-- | lib/Transforms/InstCombine/CMakeLists.txt | 5 | ||||
| -rw-r--r-- | lib/Transforms/InstCombine/InstCombineAddSub.cpp | 125 | ||||
| -rw-r--r-- | lib/Transforms/InstCombine/InstCombineAndOrXor.cpp | 326 | ||||
| -rw-r--r-- | lib/Transforms/InstCombine/InstCombineCalls.cpp | 717 | ||||
| -rw-r--r-- | lib/Transforms/InstCombine/InstCombineCasts.cpp | 381 | ||||
| -rw-r--r-- | lib/Transforms/InstCombine/InstCombineCompares.cpp | 4346 | ||||
| -rw-r--r-- | lib/Transforms/InstCombine/InstCombineInternal.h | 237 | ||||
| -rw-r--r-- | lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp | 92 | ||||
| -rw-r--r-- | lib/Transforms/InstCombine/InstCombineMulDivRem.cpp | 130 | ||||
| -rw-r--r-- | lib/Transforms/InstCombine/InstCombinePHI.cpp | 38 | ||||
| -rw-r--r-- | lib/Transforms/InstCombine/InstCombineSelect.cpp | 529 | ||||
| -rw-r--r-- | lib/Transforms/InstCombine/InstCombineShifts.cpp | 337 | ||||
| -rw-r--r-- | lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp | 315 | ||||
| -rw-r--r-- | lib/Transforms/InstCombine/InstCombineVectorOps.cpp | 203 | ||||
| -rw-r--r-- | lib/Transforms/InstCombine/InstructionCombining.cpp | 262 |
15 files changed, 4793 insertions, 3250 deletions
diff --git a/lib/Transforms/InstCombine/CMakeLists.txt b/lib/Transforms/InstCombine/CMakeLists.txt index 0ed8e6273dbc..5cbe804ce3ec 100644 --- a/lib/Transforms/InstCombine/CMakeLists.txt +++ b/lib/Transforms/InstCombine/CMakeLists.txt @@ -16,6 +16,7 @@ add_llvm_library(LLVMInstCombine ADDITIONAL_HEADER_DIRS ${LLVM_MAIN_INCLUDE_DIR}/llvm/Transforms ${LLVM_MAIN_INCLUDE_DIR}/llvm/Transforms/InstCombine - ) -add_dependencies(LLVMInstCombine intrinsics_gen) + DEPENDS + intrinsics_gen + ) diff --git a/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/lib/Transforms/InstCombine/InstCombineAddSub.cpp index 221a22007173..3bbc70ab21c6 100644 --- a/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -1035,7 +1035,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { return replaceInstUsesWith(I, V); if (Value *V = SimplifyAddInst(LHS, RHS, I.hasNoSignedWrap(), - I.hasNoUnsignedWrap(), DL, TLI, DT, AC)) + I.hasNoUnsignedWrap(), DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); // (A*B)+(A*C) -> A*(B+C) etc @@ -1047,6 +1047,16 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { // X + (signbit) --> X ^ signbit if (Val->isSignBit()) return BinaryOperator::CreateXor(LHS, RHS); + + // Is this add the last step in a convoluted sext? + Value *X; + const APInt *C; + if (match(LHS, m_ZExt(m_Xor(m_Value(X), m_APInt(C)))) && + C->isMinSignedValue() && + C->sext(LHS->getType()->getScalarSizeInBits()) == *Val) { + // add(zext(xor i16 X, -32768), -32768) --> sext X + return CastInst::Create(Instruction::SExt, X, LHS->getType()); + } } // FIXME: Use the match above instead of dyn_cast to allow these transforms @@ -1144,7 +1154,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { return replaceInstUsesWith(I, V); // A+B --> A|B iff A and B have no bits set in common. - if (haveNoCommonBitsSet(LHS, RHS, DL, AC, &I, DT)) + if (haveNoCommonBitsSet(LHS, RHS, DL, &AC, &I, &DT)) return BinaryOperator::CreateOr(LHS, RHS); if (Constant *CRHS = dyn_cast<Constant>(RHS)) { @@ -1216,15 +1226,16 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { if (SExtInst *LHSConv = dyn_cast<SExtInst>(LHS)) { // (add (sext x), cst) --> (sext (add x, cst')) if (ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS)) { - Constant *CI = - ConstantExpr::getTrunc(RHSC, LHSConv->getOperand(0)->getType()); - if (LHSConv->hasOneUse() && - ConstantExpr::getSExt(CI, I.getType()) == RHSC && - WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI, I)) { - // Insert the new, smaller add. - Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0), - CI, "addconv"); - return new SExtInst(NewAdd, I.getType()); + if (LHSConv->hasOneUse()) { + Constant *CI = + ConstantExpr::getTrunc(RHSC, LHSConv->getOperand(0)->getType()); + if (ConstantExpr::getSExt(CI, I.getType()) == RHSC && + WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI, I)) { + // Insert the new, smaller add. + Value *NewAdd = + Builder->CreateNSWAdd(LHSConv->getOperand(0), CI, "addconv"); + return new SExtInst(NewAdd, I.getType()); + } } } @@ -1246,6 +1257,44 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { } } + // Check for (add (zext x), y), see if we can merge this into an + // integer add followed by a zext. + if (auto *LHSConv = dyn_cast<ZExtInst>(LHS)) { + // (add (zext x), cst) --> (zext (add x, cst')) + if (ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS)) { + if (LHSConv->hasOneUse()) { + Constant *CI = + ConstantExpr::getTrunc(RHSC, LHSConv->getOperand(0)->getType()); + if (ConstantExpr::getZExt(CI, I.getType()) == RHSC && + computeOverflowForUnsignedAdd(LHSConv->getOperand(0), CI, &I) == + OverflowResult::NeverOverflows) { + // Insert the new, smaller add. + Value *NewAdd = + Builder->CreateNUWAdd(LHSConv->getOperand(0), CI, "addconv"); + return new ZExtInst(NewAdd, I.getType()); + } + } + } + + // (add (zext x), (zext y)) --> (zext (add int x, y)) + if (auto *RHSConv = dyn_cast<ZExtInst>(RHS)) { + // Only do this if x/y have the same type, if at last one of them has a + // single use (so we don't increase the number of zexts), and if the + // integer add will not overflow. + if (LHSConv->getOperand(0)->getType() == + RHSConv->getOperand(0)->getType() && + (LHSConv->hasOneUse() || RHSConv->hasOneUse()) && + computeOverflowForUnsignedAdd(LHSConv->getOperand(0), + RHSConv->getOperand(0), + &I) == OverflowResult::NeverOverflows) { + // Insert the new integer add. + Value *NewAdd = Builder->CreateNUWAdd( + LHSConv->getOperand(0), RHSConv->getOperand(0), "addconv"); + return new ZExtInst(NewAdd, I.getType()); + } + } + } + // (add (xor A, B) (and A, B)) --> (or A, B) { Value *A = nullptr, *B = nullptr; @@ -1307,7 +1356,7 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { return replaceInstUsesWith(I, V); if (Value *V = - SimplifyFAddInst(LHS, RHS, I.getFastMathFlags(), DL, TLI, DT, AC)) + SimplifyFAddInst(LHS, RHS, I.getFastMathFlags(), DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); if (isa<Constant>(RHS)) { @@ -1483,7 +1532,7 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { return replaceInstUsesWith(I, V); if (Value *V = SimplifySubInst(Op0, Op1, I.hasNoSignedWrap(), - I.hasNoUnsignedWrap(), DL, TLI, DT, AC)) + I.hasNoUnsignedWrap(), DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); // (A*B)-(A*C) -> A*(B-C) etc @@ -1544,34 +1593,35 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { return CastInst::CreateZExtOrBitCast(X, Op1->getType()); } - if (ConstantInt *C = dyn_cast<ConstantInt>(Op0)) { + const APInt *Op0C; + if (match(Op0, m_APInt(Op0C))) { + unsigned BitWidth = I.getType()->getScalarSizeInBits(); + // -(X >>u 31) -> (X >>s 31) // -(X >>s 31) -> (X >>u 31) - if (C->isZero()) { + if (*Op0C == 0) { Value *X; - ConstantInt *CI; - if (match(Op1, m_LShr(m_Value(X), m_ConstantInt(CI))) && - // Verify we are shifting out everything but the sign bit. - CI->getValue() == I.getType()->getPrimitiveSizeInBits() - 1) - return BinaryOperator::CreateAShr(X, CI); - - if (match(Op1, m_AShr(m_Value(X), m_ConstantInt(CI))) && - // Verify we are shifting out everything but the sign bit. - CI->getValue() == I.getType()->getPrimitiveSizeInBits() - 1) - return BinaryOperator::CreateLShr(X, CI); + const APInt *ShAmt; + if (match(Op1, m_LShr(m_Value(X), m_APInt(ShAmt))) && + *ShAmt == BitWidth - 1) { + Value *ShAmtOp = cast<Instruction>(Op1)->getOperand(1); + return BinaryOperator::CreateAShr(X, ShAmtOp); + } + if (match(Op1, m_AShr(m_Value(X), m_APInt(ShAmt))) && + *ShAmt == BitWidth - 1) { + Value *ShAmtOp = cast<Instruction>(Op1)->getOperand(1); + return BinaryOperator::CreateLShr(X, ShAmtOp); + } } // Turn this into a xor if LHS is 2^n-1 and the remaining bits are known // zero. - APInt IntVal = C->getValue(); - if ((IntVal + 1).isPowerOf2()) { - unsigned BitWidth = I.getType()->getScalarSizeInBits(); + if ((*Op0C + 1).isPowerOf2()) { APInt KnownZero(BitWidth, 0); APInt KnownOne(BitWidth, 0); computeKnownBits(&I, KnownZero, KnownOne, 0, &I); - if ((IntVal | KnownZero).isAllOnesValue()) { - return BinaryOperator::CreateXor(Op1, C); - } + if ((*Op0C | KnownZero).isAllOnesValue()) + return BinaryOperator::CreateXor(Op1, Op0); } } @@ -1632,6 +1682,17 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { if (Value *XNeg = dyn_castNegVal(X)) return BinaryOperator::CreateShl(XNeg, Y); + // Subtracting -1/0 is the same as adding 1/0: + // sub [nsw] Op0, sext(bool Y) -> add [nsw] Op0, zext(bool Y) + // 'nuw' is dropped in favor of the canonical form. + if (match(Op1, m_SExt(m_Value(Y))) && + Y->getType()->getScalarSizeInBits() == 1) { + Value *Zext = Builder->CreateZExt(Y, I.getType()); + BinaryOperator *Add = BinaryOperator::CreateAdd(Op0, Zext); + Add->setHasNoSignedWrap(I.hasNoSignedWrap()); + return Add; + } + // X - A*-B -> X + A*B // X - -A*B -> X + A*B Value *A, *B; @@ -1682,7 +1743,7 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { return replaceInstUsesWith(I, V); if (Value *V = - SimplifyFSubInst(Op0, Op1, I.getFastMathFlags(), DL, TLI, DT, AC)) + SimplifyFSubInst(Op0, Op1, I.getFastMathFlags(), DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); // fsub nsz 0, X ==> fsub nsz -0.0, X diff --git a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 1a6459b3d689..a59b43d6af5f 100644 --- a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -98,12 +98,11 @@ Value *InstCombiner::SimplifyBSwap(BinaryOperator &I) { IntegerType *ITy = dyn_cast<IntegerType>(I.getType()); // Can't do vectors. - if (I.getType()->isVectorTy()) return nullptr; + if (I.getType()->isVectorTy()) + return nullptr; // Can only do bitwise ops. - unsigned Op = I.getOpcode(); - if (Op != Instruction::And && Op != Instruction::Or && - Op != Instruction::Xor) + if (!I.isBitwiseLogicOp()) return nullptr; Value *OldLHS = I.getOperand(0); @@ -132,14 +131,7 @@ Value *InstCombiner::SimplifyBSwap(BinaryOperator &I) { Value *NewRHS = IsBswapRHS ? IntrRHS->getOperand(0) : Builder->getInt(ConstRHS->getValue().byteSwap()); - Value *BinOp = nullptr; - if (Op == Instruction::And) - BinOp = Builder->CreateAnd(NewLHS, NewRHS); - else if (Op == Instruction::Or) - BinOp = Builder->CreateOr(NewLHS, NewRHS); - else //if (Op == Instruction::Xor) - BinOp = Builder->CreateXor(NewLHS, NewRHS); - + Value *BinOp = Builder->CreateBinOp(I.getOpcode(), NewLHS, NewRHS); Function *F = Intrinsic::getDeclaration(I.getModule(), Intrinsic::bswap, ITy); return Builder->CreateCall(F, BinOp); } @@ -283,51 +275,31 @@ Instruction *InstCombiner::OptAndOp(Instruction *Op, } /// Emit a computation of: (V >= Lo && V < Hi) if Inside is true, otherwise -/// (V < Lo || V >= Hi). In practice, we emit the more efficient -/// (V-Lo) \<u Hi-Lo. This method expects that Lo <= Hi. isSigned indicates -/// whether to treat the V, Lo and HI as signed or not. IB is the location to -/// insert new instructions. -Value *InstCombiner::InsertRangeTest(Value *V, Constant *Lo, Constant *Hi, +/// (V < Lo || V >= Hi). This method expects that Lo <= Hi. IsSigned indicates +/// whether to treat V, Lo, and Hi as signed or not. +Value *InstCombiner::insertRangeTest(Value *V, const APInt &Lo, const APInt &Hi, bool isSigned, bool Inside) { - assert(cast<ConstantInt>(ConstantExpr::getICmp((isSigned ? - ICmpInst::ICMP_SLE:ICmpInst::ICMP_ULE), Lo, Hi))->getZExtValue() && + assert((isSigned ? Lo.sle(Hi) : Lo.ule(Hi)) && "Lo is not <= Hi in range emission code!"); - if (Inside) { - if (Lo == Hi) // Trivially false. - return Builder->getFalse(); - - // V >= Min && V < Hi --> V < Hi - if (cast<ConstantInt>(Lo)->isMinValue(isSigned)) { - ICmpInst::Predicate pred = (isSigned ? - ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT); - return Builder->CreateICmp(pred, V, Hi); - } - - // Emit V-Lo <u Hi-Lo - Constant *NegLo = ConstantExpr::getNeg(Lo); - Value *Add = Builder->CreateAdd(V, NegLo, V->getName()+".off"); - Constant *UpperBound = ConstantExpr::getAdd(NegLo, Hi); - return Builder->CreateICmpULT(Add, UpperBound); - } - - if (Lo == Hi) // Trivially true. - return Builder->getTrue(); + Type *Ty = V->getType(); + if (Lo == Hi) + return Inside ? ConstantInt::getFalse(Ty) : ConstantInt::getTrue(Ty); - // V < Min || V >= Hi -> V > Hi-1 - Hi = SubOne(cast<ConstantInt>(Hi)); - if (cast<ConstantInt>(Lo)->isMinValue(isSigned)) { - ICmpInst::Predicate pred = (isSigned ? - ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT); - return Builder->CreateICmp(pred, V, Hi); + // V >= Min && V < Hi --> V < Hi + // V < Min || V >= Hi --> V >= Hi + ICmpInst::Predicate Pred = Inside ? ICmpInst::ICMP_ULT : ICmpInst::ICMP_UGE; + if (isSigned ? Lo.isMinSignedValue() : Lo.isMinValue()) { + Pred = isSigned ? ICmpInst::getSignedPredicate(Pred) : Pred; + return Builder->CreateICmp(Pred, V, ConstantInt::get(Ty, Hi)); } - // Emit V-Lo >u Hi-1-Lo - // Note that Hi has already had one subtracted from it, above. - ConstantInt *NegLo = cast<ConstantInt>(ConstantExpr::getNeg(Lo)); - Value *Add = Builder->CreateAdd(V, NegLo, V->getName()+".off"); - Constant *LowerBound = ConstantExpr::getAdd(NegLo, Hi); - return Builder->CreateICmpUGT(Add, LowerBound); + // V >= Lo && V < Hi --> V - Lo u< Hi - Lo + // V < Lo || V >= Hi --> V - Lo u>= Hi - Lo + Value *VMinusLo = + Builder->CreateSub(V, ConstantInt::get(Ty, Lo), V->getName() + ".off"); + Constant *HiMinusLo = ConstantInt::get(Ty, Hi - Lo); + return Builder->CreateICmp(Pred, VMinusLo, HiMinusLo); } /// Returns true iff Val consists of one contiguous run of 1s with any number @@ -524,53 +496,6 @@ static unsigned conjugateICmpMask(unsigned Mask) { return NewMask; } -/// Decompose an icmp into the form ((X & Y) pred Z) if possible. -/// The returned predicate is either == or !=. Returns false if -/// decomposition fails. -static bool decomposeBitTestICmp(const ICmpInst *I, ICmpInst::Predicate &Pred, - Value *&X, Value *&Y, Value *&Z) { - ConstantInt *C = dyn_cast<ConstantInt>(I->getOperand(1)); - if (!C) - return false; - - switch (I->getPredicate()) { - default: - return false; - case ICmpInst::ICMP_SLT: - // X < 0 is equivalent to (X & SignBit) != 0. - if (!C->isZero()) - return false; - Y = ConstantInt::get(I->getContext(), APInt::getSignBit(C->getBitWidth())); - Pred = ICmpInst::ICMP_NE; - break; - case ICmpInst::ICMP_SGT: - // X > -1 is equivalent to (X & SignBit) == 0. - if (!C->isAllOnesValue()) - return false; - Y = ConstantInt::get(I->getContext(), APInt::getSignBit(C->getBitWidth())); - Pred = ICmpInst::ICMP_EQ; - break; - case ICmpInst::ICMP_ULT: - // X <u 2^n is equivalent to (X & ~(2^n-1)) == 0. - if (!C->getValue().isPowerOf2()) - return false; - Y = ConstantInt::get(I->getContext(), -C->getValue()); - Pred = ICmpInst::ICMP_EQ; - break; - case ICmpInst::ICMP_UGT: - // X >u 2^n-1 is equivalent to (X & ~(2^n-1)) != 0. - if (!(C->getValue() + 1).isPowerOf2()) - return false; - Y = ConstantInt::get(I->getContext(), ~C->getValue()); - Pred = ICmpInst::ICMP_NE; - break; - } - - X = I->getOperand(0); - Z = ConstantInt::getNullValue(C->getType()); - return true; -} - /// Handle (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E) /// Return the set of pattern classes (from MaskedICmpType) /// that both LHS and RHS satisfy. @@ -1001,7 +926,8 @@ Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) { if (LHSCst == SubOne(RHSCst)) // (X != 13 & X u< 14) -> X < 13 return Builder->CreateICmpULT(Val, LHSCst); if (LHSCst->isNullValue()) // (X != 0 & X u< 14) -> X-1 u< 13 - return InsertRangeTest(Val, AddOne(LHSCst), RHSCst, false, true); + return insertRangeTest(Val, LHSCst->getValue() + 1, RHSCst->getValue(), + false, true); break; // (X != 13 & X u< 15) -> no change case ICmpInst::ICMP_SLT: if (LHSCst == SubOne(RHSCst)) // (X != 13 & X s< 14) -> X < 13 @@ -1065,7 +991,8 @@ Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) { return Builder->CreateICmp(LHSCC, Val, RHSCst); break; // (X u> 13 & X != 15) -> no change case ICmpInst::ICMP_ULT: // (X u> 13 & X u< 15) -> (X-14) <u 1 - return InsertRangeTest(Val, AddOne(LHSCst), RHSCst, false, true); + return insertRangeTest(Val, LHSCst->getValue() + 1, RHSCst->getValue(), + false, true); case ICmpInst::ICMP_SLT: // (X u> 13 & X s< 15) -> no change break; } @@ -1083,7 +1010,8 @@ Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) { return Builder->CreateICmp(LHSCC, Val, RHSCst); break; // (X s> 13 & X != 15) -> no change case ICmpInst::ICMP_SLT: // (X s> 13 & X s< 15) -> (X-14) s< 1 - return InsertRangeTest(Val, AddOne(LHSCst), RHSCst, true, true); + return insertRangeTest(Val, LHSCst->getValue() + 1, RHSCst->getValue(), + true, true); case ICmpInst::ICMP_ULT: // (X s> 13 & X u< 15) -> no change break; } @@ -1170,34 +1098,73 @@ static Instruction *matchDeMorgansLaws(BinaryOperator &I, return BinaryOperator::CreateNot(LogicOp); } - // De Morgan's Law in disguise: - // (zext(bool A) ^ 1) & (zext(bool B) ^ 1) -> zext(~(A | B)) - // (zext(bool A) ^ 1) | (zext(bool B) ^ 1) -> zext(~(A & B)) - Value *A = nullptr; - Value *B = nullptr; - ConstantInt *C1 = nullptr; - if (match(Op0, m_OneUse(m_Xor(m_ZExt(m_Value(A)), m_ConstantInt(C1)))) && - match(Op1, m_OneUse(m_Xor(m_ZExt(m_Value(B)), m_Specific(C1))))) { - // TODO: This check could be loosened to handle different type sizes. - // Alternatively, we could fix the definition of m_Not to recognize a not - // operation hidden by a zext? - if (A->getType()->isIntegerTy(1) && B->getType()->isIntegerTy(1) && - C1->isOne()) { - Value *LogicOp = Builder->CreateBinOp(Opcode, A, B, - I.getName() + ".demorgan"); - Value *Not = Builder->CreateNot(LogicOp); - return CastInst::CreateZExtOrBitCast(Not, I.getType()); + return nullptr; +} + +bool InstCombiner::shouldOptimizeCast(CastInst *CI) { + Value *CastSrc = CI->getOperand(0); + + // Noop casts and casts of constants should be eliminated trivially. + if (CI->getSrcTy() == CI->getDestTy() || isa<Constant>(CastSrc)) + return false; + + // If this cast is paired with another cast that can be eliminated, we prefer + // to have it eliminated. + if (const auto *PrecedingCI = dyn_cast<CastInst>(CastSrc)) + if (isEliminableCastPair(PrecedingCI, CI)) + return false; + + // If this is a vector sext from a compare, then we don't want to break the + // idiom where each element of the extended vector is either zero or all ones. + if (CI->getOpcode() == Instruction::SExt && + isa<CmpInst>(CastSrc) && CI->getDestTy()->isVectorTy()) + return false; + + return true; +} + +/// Fold {and,or,xor} (cast X), C. +static Instruction *foldLogicCastConstant(BinaryOperator &Logic, CastInst *Cast, + InstCombiner::BuilderTy *Builder) { + Constant *C; + if (!match(Logic.getOperand(1), m_Constant(C))) + return nullptr; + + auto LogicOpc = Logic.getOpcode(); + Type *DestTy = Logic.getType(); + Type *SrcTy = Cast->getSrcTy(); + + // If the first operand is bitcast, move the logic operation ahead of the + // bitcast (do the logic operation in the original type). This can eliminate + // bitcasts and allow combines that would otherwise be impeded by the bitcast. + Value *X; + if (match(Cast, m_BitCast(m_Value(X)))) { + Value *NewConstant = ConstantExpr::getBitCast(C, SrcTy); + Value *NewOp = Builder->CreateBinOp(LogicOpc, X, NewConstant); + return CastInst::CreateBitOrPointerCast(NewOp, DestTy); + } + + // Similarly, move the logic operation ahead of a zext if the constant is + // unchanged in the smaller source type. Performing the logic in a smaller + // type may provide more information to later folds, and the smaller logic + // instruction may be cheaper (particularly in the case of vectors). + if (match(Cast, m_OneUse(m_ZExt(m_Value(X))))) { + Constant *TruncC = ConstantExpr::getTrunc(C, SrcTy); + Constant *ZextTruncC = ConstantExpr::getZExt(TruncC, DestTy); + if (ZextTruncC == C) { + // LogicOpc (zext X), C --> zext (LogicOpc X, C) + Value *NewOp = Builder->CreateBinOp(LogicOpc, X, TruncC); + return new ZExtInst(NewOp, DestTy); } } return nullptr; } +/// Fold {and,or,xor} (cast X), Y. Instruction *InstCombiner::foldCastedBitwiseLogic(BinaryOperator &I) { auto LogicOpc = I.getOpcode(); - assert((LogicOpc == Instruction::And || LogicOpc == Instruction::Or || - LogicOpc == Instruction::Xor) && - "Unexpected opcode for bitwise logic folding"); + assert(I.isBitwiseLogicOp() && "Unexpected opcode for bitwise logic folding"); Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); CastInst *Cast0 = dyn_cast<CastInst>(Op0); @@ -1211,18 +1178,8 @@ Instruction *InstCombiner::foldCastedBitwiseLogic(BinaryOperator &I) { if (!SrcTy->isIntOrIntVectorTy()) return nullptr; - // If one operand is a bitcast and the other is a constant, move the logic - // operation ahead of the bitcast. That is, do the logic operation in the - // original type. This can eliminate useless bitcasts and allow normal - // combines that would otherwise be impeded by the bitcast. Canonicalization - // ensures that if there is a constant operand, it will be the second operand. - Value *BC = nullptr; - Constant *C = nullptr; - if ((match(Op0, m_BitCast(m_Value(BC))) && match(Op1, m_Constant(C)))) { - Value *NewConstant = ConstantExpr::getBitCast(C, SrcTy); - Value *NewOp = Builder->CreateBinOp(LogicOpc, BC, NewConstant, I.getName()); - return CastInst::CreateBitOrPointerCast(NewOp, DestTy); - } + if (Instruction *Ret = foldLogicCastConstant(I, Cast0, Builder)) + return Ret; CastInst *Cast1 = dyn_cast<CastInst>(Op1); if (!Cast1) @@ -1237,12 +1194,8 @@ Instruction *InstCombiner::foldCastedBitwiseLogic(BinaryOperator &I) { Value *Cast0Src = Cast0->getOperand(0); Value *Cast1Src = Cast1->getOperand(0); - // fold (logic (cast A), (cast B)) -> (cast (logic A, B)) - - // Only do this if the casts both really cause code to be generated. - if ((!isa<ICmpInst>(Cast0Src) || !isa<ICmpInst>(Cast1Src)) && - ShouldOptimizeCast(CastOpcode, Cast0Src, DestTy) && - ShouldOptimizeCast(CastOpcode, Cast1Src, DestTy)) { + // fold logic(cast(A), cast(B)) -> cast(logic(A, B)) + if (shouldOptimizeCast(Cast0) && shouldOptimizeCast(Cast1)) { Value *NewOp = Builder->CreateBinOp(LogicOpc, Cast0Src, Cast1Src, I.getName()); return CastInst::Create(CastOpcode, NewOp, DestTy); @@ -1301,10 +1254,13 @@ static Instruction *foldBoolSextMaskToSelect(BinaryOperator &I) { Value *Zero = Constant::getNullValue(Op0->getType()); return SelectInst::Create(X, Zero, Op1); } - + return nullptr; } +// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches +// here. We should standardize that construct where it is needed or choose some +// other way to ensure that commutated variants of patterns are not missed. Instruction *InstCombiner::visitAnd(BinaryOperator &I) { bool Changed = SimplifyAssociativeOrCommutative(I); Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); @@ -1312,7 +1268,7 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); - if (Value *V = SimplifyAndInst(Op0, Op1, DL, TLI, DT, AC)) + if (Value *V = SimplifyAndInst(Op0, Op1, DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); // (A|B)&(A|C) -> A|(B&C) etc @@ -1503,8 +1459,9 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { return BinaryOperator::CreateAnd(A, B); // ((~A) ^ B) & (A | B) -> (A & B) + // ((~A) ^ B) & (B | A) -> (A & B) if (match(Op0, m_Xor(m_Not(m_Value(A)), m_Value(B))) && - match(Op1, m_Or(m_Specific(A), m_Specific(B)))) + match(Op1, m_c_Or(m_Specific(A), m_Specific(B)))) return BinaryOperator::CreateAnd(A, B); } @@ -1697,17 +1654,17 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, Value *Mask = nullptr; Value *Masked = nullptr; if (LAnd->getOperand(0) == RAnd->getOperand(0) && - isKnownToBeAPowerOfTwo(LAnd->getOperand(1), DL, false, 0, AC, CxtI, - DT) && - isKnownToBeAPowerOfTwo(RAnd->getOperand(1), DL, false, 0, AC, CxtI, - DT)) { + isKnownToBeAPowerOfTwo(LAnd->getOperand(1), DL, false, 0, &AC, CxtI, + &DT) && + isKnownToBeAPowerOfTwo(RAnd->getOperand(1), DL, false, 0, &AC, CxtI, + &DT)) { Mask = Builder->CreateOr(LAnd->getOperand(1), RAnd->getOperand(1)); Masked = Builder->CreateAnd(LAnd->getOperand(0), Mask); } else if (LAnd->getOperand(1) == RAnd->getOperand(1) && - isKnownToBeAPowerOfTwo(LAnd->getOperand(0), DL, false, 0, AC, - CxtI, DT) && - isKnownToBeAPowerOfTwo(RAnd->getOperand(0), DL, false, 0, AC, - CxtI, DT)) { + isKnownToBeAPowerOfTwo(LAnd->getOperand(0), DL, false, 0, &AC, + CxtI, &DT) && + isKnownToBeAPowerOfTwo(RAnd->getOperand(0), DL, false, 0, &AC, + CxtI, &DT)) { Mask = Builder->CreateOr(LAnd->getOperand(0), RAnd->getOperand(0)); Masked = Builder->CreateAnd(LAnd->getOperand(1), Mask); } @@ -1825,7 +1782,7 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, // E.g. (icmp sgt x, n) | (icmp slt x, 0) --> icmp ugt x, n if (Value *V = simplifyRangeCheck(RHS, LHS, /*Inverted=*/true)) return V; - + // This only handles icmp of constants: (icmp1 A, C1) | (icmp2 B, C2). if (!LHSCst || !RHSCst) return nullptr; @@ -1943,7 +1900,8 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, // this can cause overflow. if (RHSCst->isMaxValue(false)) return LHS; - return InsertRangeTest(Val, LHSCst, AddOne(RHSCst), false, false); + return insertRangeTest(Val, LHSCst->getValue(), RHSCst->getValue() + 1, + false, false); case ICmpInst::ICMP_SGT: // (X u< 13 | X s> 15) -> no change break; case ICmpInst::ICMP_NE: // (X u< 13 | X != 15) -> X != 15 @@ -1963,7 +1921,8 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, // this can cause overflow. if (RHSCst->isMaxValue(true)) return LHS; - return InsertRangeTest(Val, LHSCst, AddOne(RHSCst), true, false); + return insertRangeTest(Val, LHSCst->getValue(), RHSCst->getValue() + 1, + true, false); case ICmpInst::ICMP_UGT: // (X s< 13 | X u> 15) -> no change break; case ICmpInst::ICMP_NE: // (X s< 13 | X != 15) -> X != 15 @@ -2119,6 +2078,9 @@ Instruction *InstCombiner::FoldXorWithConstants(BinaryOperator &I, Value *Op, return nullptr; } +// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches +// here. We should standardize that construct where it is needed or choose some +// other way to ensure that commutated variants of patterns are not missed. Instruction *InstCombiner::visitOr(BinaryOperator &I) { bool Changed = SimplifyAssociativeOrCommutative(I); Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); @@ -2126,7 +2088,7 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); - if (Value *V = SimplifyOrInst(Op0, Op1, DL, TLI, DT, AC)) + if (Value *V = SimplifyOrInst(Op0, Op1, DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); // (A&B)|(A&C) -> A&(B|C) etc @@ -2208,14 +2170,17 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { match(Op1, m_Not(m_Specific(A)))) return BinaryOperator::CreateOr(Builder->CreateNot(A), B); - // (A & (~B)) | (A ^ B) -> (A ^ B) - if (match(Op0, m_And(m_Value(A), m_Not(m_Value(B)))) && + // (A & ~B) | (A ^ B) -> (A ^ B) + // (~B & A) | (A ^ B) -> (A ^ B) + if (match(Op0, m_c_And(m_Value(A), m_Not(m_Value(B)))) && match(Op1, m_Xor(m_Specific(A), m_Specific(B)))) return BinaryOperator::CreateXor(A, B); - // (A ^ B) | ( A & (~B)) -> (A ^ B) - if (match(Op0, m_Xor(m_Value(A), m_Value(B))) && - match(Op1, m_And(m_Specific(A), m_Not(m_Specific(B))))) + // Commute the 'or' operands. + // (A ^ B) | (A & ~B) -> (A ^ B) + // (A ^ B) | (~B & A) -> (A ^ B) + if (match(Op1, m_c_And(m_Value(A), m_Not(m_Value(B)))) && + match(Op0, m_Xor(m_Specific(A), m_Specific(B)))) return BinaryOperator::CreateXor(A, B); // (A & C)|(B & D) @@ -2385,14 +2350,15 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { return BinaryOperator::CreateOr(Not, Op0); } - // (A & B) | ((~A) ^ B) -> (~A ^ B) - if (match(Op0, m_And(m_Value(A), m_Value(B))) && - match(Op1, m_Xor(m_Not(m_Specific(A)), m_Specific(B)))) - return BinaryOperator::CreateXor(Builder->CreateNot(A), B); - - // ((~A) ^ B) | (A & B) -> (~A ^ B) - if (match(Op0, m_Xor(m_Not(m_Value(A)), m_Value(B))) && - match(Op1, m_And(m_Specific(A), m_Specific(B)))) + // (A & B) | (~A ^ B) -> (~A ^ B) + // (A & B) | (B ^ ~A) -> (~A ^ B) + // (B & A) | (~A ^ B) -> (~A ^ B) + // (B & A) | (B ^ ~A) -> (~A ^ B) + // The match order is important: match the xor first because the 'not' + // operation defines 'A'. We do not need to match the xor as Op0 because the + // xor was canonicalized to Op1 above. + if (match(Op1, m_c_Xor(m_Not(m_Value(A)), m_Value(B))) && + match(Op0, m_c_And(m_Specific(A), m_Specific(B)))) return BinaryOperator::CreateXor(Builder->CreateNot(A), B); if (SwappedForXor) @@ -2472,6 +2438,9 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { return Changed ? &I : nullptr; } +// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches +// here. We should standardize that construct where it is needed or choose some +// other way to ensure that commutated variants of patterns are not missed. Instruction *InstCombiner::visitXor(BinaryOperator &I) { bool Changed = SimplifyAssociativeOrCommutative(I); Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); @@ -2479,7 +2448,7 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); - if (Value *V = SimplifyXorInst(Op0, Op1, DL, TLI, DT, AC)) + if (Value *V = SimplifyXorInst(Op0, Op1, DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); // (A&B)^(A&C) -> A&(B^C) etc @@ -2694,20 +2663,22 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { return BinaryOperator::CreateXor(A, B); } // (A | ~B) ^ (~A | B) -> A ^ B - if (match(Op0I, m_Or(m_Value(A), m_Not(m_Value(B)))) && - match(Op1I, m_Or(m_Not(m_Specific(A)), m_Specific(B)))) { + // (~B | A) ^ (~A | B) -> A ^ B + if (match(Op0I, m_c_Or(m_Value(A), m_Not(m_Value(B)))) && + match(Op1I, m_Or(m_Not(m_Specific(A)), m_Specific(B)))) return BinaryOperator::CreateXor(A, B); - } + // (~A | B) ^ (A | ~B) -> A ^ B if (match(Op0I, m_Or(m_Not(m_Value(A)), m_Value(B))) && match(Op1I, m_Or(m_Specific(A), m_Not(m_Specific(B))))) { return BinaryOperator::CreateXor(A, B); } // (A & ~B) ^ (~A & B) -> A ^ B - if (match(Op0I, m_And(m_Value(A), m_Not(m_Value(B)))) && - match(Op1I, m_And(m_Not(m_Specific(A)), m_Specific(B)))) { + // (~B & A) ^ (~A & B) -> A ^ B + if (match(Op0I, m_c_And(m_Value(A), m_Not(m_Value(B)))) && + match(Op1I, m_And(m_Not(m_Specific(A)), m_Specific(B)))) return BinaryOperator::CreateXor(A, B); - } + // (~A & B) ^ (A & ~B) -> A ^ B if (match(Op0I, m_And(m_Not(m_Value(A)), m_Value(B))) && match(Op1I, m_And(m_Specific(A), m_Not(m_Specific(B))))) { @@ -2743,9 +2714,10 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { return BinaryOperator::CreateOr(A, B); } - Value *A = nullptr, *B = nullptr; - // (A & ~B) ^ (~A) -> ~(A & B) - if (match(Op0, m_And(m_Value(A), m_Not(m_Value(B)))) && + // (A & ~B) ^ ~A -> ~(A & B) + // (~B & A) ^ ~A -> ~(A & B) + Value *A, *B; + if (match(Op0, m_c_And(m_Value(A), m_Not(m_Value(B)))) && match(Op1, m_Not(m_Specific(A)))) return BinaryOperator::CreateNot(Builder->CreateAnd(A, B)); diff --git a/lib/Transforms/InstCombine/InstCombineCalls.cpp b/lib/Transforms/InstCombine/InstCombineCalls.cpp index 8acff91345d6..92369bd70b13 100644 --- a/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -12,17 +12,47 @@ //===----------------------------------------------------------------------===// #include "InstCombineInternal.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/None.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Twine.h" #include "llvm/Analysis/InstructionSimplify.h" -#include "llvm/Analysis/Loads.h" #include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/BasicBlock.h" #include "llvm/IR/CallSite.h" -#include "llvm/IR/Dominators.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" #include "llvm/IR/PatternMatch.h" #include "llvm/IR/Statepoint.h" -#include "llvm/Transforms/Utils/BuildLibCalls.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/IR/ValueHandle.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/SimplifyLibCalls.h" +#include <algorithm> +#include <cassert> +#include <cstdint> +#include <cstring> +#include <vector> + using namespace llvm; using namespace PatternMatch; @@ -79,8 +109,8 @@ static Constant *getNegativeIsTrueBoolVec(ConstantDataVector *V) { } Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) { - unsigned DstAlign = getKnownAlignment(MI->getArgOperand(0), DL, MI, AC, DT); - unsigned SrcAlign = getKnownAlignment(MI->getArgOperand(1), DL, MI, AC, DT); + unsigned DstAlign = getKnownAlignment(MI->getArgOperand(0), DL, MI, &AC, &DT); + unsigned SrcAlign = getKnownAlignment(MI->getArgOperand(1), DL, MI, &AC, &DT); unsigned MinAlign = std::min(DstAlign, SrcAlign); unsigned CopyAlign = MI->getAlignment(); @@ -162,10 +192,17 @@ Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) { L->setAlignment(SrcAlign); if (CopyMD) L->setMetadata(LLVMContext::MD_tbaa, CopyMD); + MDNode *LoopMemParallelMD = + MI->getMetadata(LLVMContext::MD_mem_parallel_loop_access); + if (LoopMemParallelMD) + L->setMetadata(LLVMContext::MD_mem_parallel_loop_access, LoopMemParallelMD); + StoreInst *S = Builder->CreateStore(L, Dest, MI->isVolatile()); S->setAlignment(DstAlign); if (CopyMD) S->setMetadata(LLVMContext::MD_tbaa, CopyMD); + if (LoopMemParallelMD) + S->setMetadata(LLVMContext::MD_mem_parallel_loop_access, LoopMemParallelMD); // Set the size of the copy to 0, it will be deleted on the next iteration. MI->setArgOperand(2, Constant::getNullValue(MemOpLength->getType())); @@ -173,7 +210,7 @@ Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) { } Instruction *InstCombiner::SimplifyMemSet(MemSetInst *MI) { - unsigned Alignment = getKnownAlignment(MI->getDest(), DL, MI, AC, DT); + unsigned Alignment = getKnownAlignment(MI->getDest(), DL, MI, &AC, &DT); if (MI->getAlignment() < Alignment) { MI->setAlignment(ConstantInt::get(MI->getAlignmentType(), Alignment, false)); @@ -221,8 +258,7 @@ static Value *simplifyX86immShift(const IntrinsicInst &II, bool ShiftLeft = false; switch (II.getIntrinsicID()) { - default: - return nullptr; + default: llvm_unreachable("Unexpected intrinsic!"); case Intrinsic::x86_sse2_psra_d: case Intrinsic::x86_sse2_psra_w: case Intrinsic::x86_sse2_psrai_d: @@ -231,6 +267,16 @@ static Value *simplifyX86immShift(const IntrinsicInst &II, case Intrinsic::x86_avx2_psra_w: case Intrinsic::x86_avx2_psrai_d: case Intrinsic::x86_avx2_psrai_w: + case Intrinsic::x86_avx512_psra_q_128: + case Intrinsic::x86_avx512_psrai_q_128: + case Intrinsic::x86_avx512_psra_q_256: + case Intrinsic::x86_avx512_psrai_q_256: + case Intrinsic::x86_avx512_psra_d_512: + case Intrinsic::x86_avx512_psra_q_512: + case Intrinsic::x86_avx512_psra_w_512: + case Intrinsic::x86_avx512_psrai_d_512: + case Intrinsic::x86_avx512_psrai_q_512: + case Intrinsic::x86_avx512_psrai_w_512: LogicalShift = false; ShiftLeft = false; break; case Intrinsic::x86_sse2_psrl_d: @@ -245,6 +291,12 @@ static Value *simplifyX86immShift(const IntrinsicInst &II, case Intrinsic::x86_avx2_psrli_d: case Intrinsic::x86_avx2_psrli_q: case Intrinsic::x86_avx2_psrli_w: + case Intrinsic::x86_avx512_psrl_d_512: + case Intrinsic::x86_avx512_psrl_q_512: + case Intrinsic::x86_avx512_psrl_w_512: + case Intrinsic::x86_avx512_psrli_d_512: + case Intrinsic::x86_avx512_psrli_q_512: + case Intrinsic::x86_avx512_psrli_w_512: LogicalShift = true; ShiftLeft = false; break; case Intrinsic::x86_sse2_psll_d: @@ -259,6 +311,12 @@ static Value *simplifyX86immShift(const IntrinsicInst &II, case Intrinsic::x86_avx2_pslli_d: case Intrinsic::x86_avx2_pslli_q: case Intrinsic::x86_avx2_pslli_w: + case Intrinsic::x86_avx512_psll_d_512: + case Intrinsic::x86_avx512_psll_q_512: + case Intrinsic::x86_avx512_psll_w_512: + case Intrinsic::x86_avx512_pslli_d_512: + case Intrinsic::x86_avx512_pslli_q_512: + case Intrinsic::x86_avx512_pslli_w_512: LogicalShift = true; ShiftLeft = true; break; } @@ -334,10 +392,16 @@ static Value *simplifyX86varShift(const IntrinsicInst &II, bool ShiftLeft = false; switch (II.getIntrinsicID()) { - default: - return nullptr; + default: llvm_unreachable("Unexpected intrinsic!"); case Intrinsic::x86_avx2_psrav_d: case Intrinsic::x86_avx2_psrav_d_256: + case Intrinsic::x86_avx512_psrav_q_128: + case Intrinsic::x86_avx512_psrav_q_256: + case Intrinsic::x86_avx512_psrav_d_512: + case Intrinsic::x86_avx512_psrav_q_512: + case Intrinsic::x86_avx512_psrav_w_128: + case Intrinsic::x86_avx512_psrav_w_256: + case Intrinsic::x86_avx512_psrav_w_512: LogicalShift = false; ShiftLeft = false; break; @@ -345,6 +409,11 @@ static Value *simplifyX86varShift(const IntrinsicInst &II, case Intrinsic::x86_avx2_psrlv_d_256: case Intrinsic::x86_avx2_psrlv_q: case Intrinsic::x86_avx2_psrlv_q_256: + case Intrinsic::x86_avx512_psrlv_d_512: + case Intrinsic::x86_avx512_psrlv_q_512: + case Intrinsic::x86_avx512_psrlv_w_128: + case Intrinsic::x86_avx512_psrlv_w_256: + case Intrinsic::x86_avx512_psrlv_w_512: LogicalShift = true; ShiftLeft = false; break; @@ -352,6 +421,11 @@ static Value *simplifyX86varShift(const IntrinsicInst &II, case Intrinsic::x86_avx2_psllv_d_256: case Intrinsic::x86_avx2_psllv_q: case Intrinsic::x86_avx2_psllv_q_256: + case Intrinsic::x86_avx512_psllv_d_512: + case Intrinsic::x86_avx512_psllv_q_512: + case Intrinsic::x86_avx512_psllv_w_128: + case Intrinsic::x86_avx512_psllv_w_256: + case Intrinsic::x86_avx512_psllv_w_512: LogicalShift = true; ShiftLeft = true; break; @@ -400,7 +474,7 @@ static Value *simplifyX86varShift(const IntrinsicInst &II, // If all elements out of range or UNDEF, return vector of zeros/undefs. // ArithmeticShift should only hit this if they are all UNDEF. auto OutOfRange = [&](int Idx) { return (Idx < 0) || (BitWidth <= Idx); }; - if (llvm::all_of(ShiftAmts, OutOfRange)) { + if (all_of(ShiftAmts, OutOfRange)) { SmallVector<Constant *, 8> ConstantVec; for (int Idx : ShiftAmts) { if (Idx < 0) { @@ -547,7 +621,7 @@ static Value *simplifyX86extrq(IntrinsicInst &II, Value *Op0, // See if we're dealing with constant values. Constant *C0 = dyn_cast<Constant>(Op0); ConstantInt *CI0 = - C0 ? dyn_cast<ConstantInt>(C0->getAggregateElement((unsigned)0)) + C0 ? dyn_cast_or_null<ConstantInt>(C0->getAggregateElement((unsigned)0)) : nullptr; // Attempt to constant fold. @@ -630,7 +704,6 @@ static Value *simplifyX86extrq(IntrinsicInst &II, Value *Op0, static Value *simplifyX86insertq(IntrinsicInst &II, Value *Op0, Value *Op1, APInt APLength, APInt APIndex, InstCombiner::BuilderTy &Builder) { - // From AMD documentation: "The bit index and field length are each six bits // in length other bits of the field are ignored." APIndex = APIndex.zextOrTrunc(6); @@ -686,10 +759,10 @@ static Value *simplifyX86insertq(IntrinsicInst &II, Value *Op0, Value *Op1, Constant *C0 = dyn_cast<Constant>(Op0); Constant *C1 = dyn_cast<Constant>(Op1); ConstantInt *CI00 = - C0 ? dyn_cast<ConstantInt>(C0->getAggregateElement((unsigned)0)) + C0 ? dyn_cast_or_null<ConstantInt>(C0->getAggregateElement((unsigned)0)) : nullptr; ConstantInt *CI10 = - C1 ? dyn_cast<ConstantInt>(C1->getAggregateElement((unsigned)0)) + C1 ? dyn_cast_or_null<ConstantInt>(C1->getAggregateElement((unsigned)0)) : nullptr; // Constant Fold - insert bottom Length bits starting at the Index'th bit. @@ -732,11 +805,11 @@ static Value *simplifyX86pshufb(const IntrinsicInst &II, auto *VecTy = cast<VectorType>(II.getType()); auto *MaskEltTy = Type::getInt32Ty(II.getContext()); unsigned NumElts = VecTy->getNumElements(); - assert((NumElts == 16 || NumElts == 32) && + assert((NumElts == 16 || NumElts == 32 || NumElts == 64) && "Unexpected number of elements in shuffle mask!"); // Construct a shuffle mask from constant integers or UNDEFs. - Constant *Indexes[32] = {NULL}; + Constant *Indexes[64] = {nullptr}; // Each byte in the shuffle control mask forms an index to permute the // corresponding byte in the destination operand. @@ -776,12 +849,15 @@ static Value *simplifyX86vpermilvar(const IntrinsicInst &II, if (!V) return nullptr; + auto *VecTy = cast<VectorType>(II.getType()); auto *MaskEltTy = Type::getInt32Ty(II.getContext()); - unsigned NumElts = cast<VectorType>(V->getType())->getNumElements(); - assert(NumElts == 8 || NumElts == 4 || NumElts == 2); + unsigned NumElts = VecTy->getVectorNumElements(); + bool IsPD = VecTy->getScalarType()->isDoubleTy(); + unsigned NumLaneElts = IsPD ? 2 : 4; + assert(NumElts == 16 || NumElts == 8 || NumElts == 4 || NumElts == 2); // Construct a shuffle mask from constant integers or UNDEFs. - Constant *Indexes[8] = {NULL}; + Constant *Indexes[16] = {nullptr}; // The intrinsics only read one or two bits, clear the rest. for (unsigned I = 0; I < NumElts; ++I) { @@ -799,18 +875,13 @@ static Value *simplifyX86vpermilvar(const IntrinsicInst &II, // The PD variants uses bit 1 to select per-lane element index, so // shift down to convert to generic shuffle mask index. - if (II.getIntrinsicID() == Intrinsic::x86_avx_vpermilvar_pd || - II.getIntrinsicID() == Intrinsic::x86_avx_vpermilvar_pd_256) + if (IsPD) Index = Index.lshr(1); // The _256 variants are a bit trickier since the mask bits always index // into the corresponding 128 half. In order to convert to a generic // shuffle, we have to make that explicit. - if ((II.getIntrinsicID() == Intrinsic::x86_avx_vpermilvar_ps_256 || - II.getIntrinsicID() == Intrinsic::x86_avx_vpermilvar_pd_256) && - ((NumElts / 2) <= I)) { - Index += APInt(32, NumElts / 2); - } + Index += APInt(32, (I / NumLaneElts) * NumLaneElts); Indexes[I] = ConstantInt::get(MaskEltTy, Index); } @@ -831,10 +902,11 @@ static Value *simplifyX86vpermv(const IntrinsicInst &II, auto *VecTy = cast<VectorType>(II.getType()); auto *MaskEltTy = Type::getInt32Ty(II.getContext()); unsigned Size = VecTy->getNumElements(); - assert(Size == 8 && "Unexpected shuffle mask size"); + assert((Size == 4 || Size == 8 || Size == 16 || Size == 32 || Size == 64) && + "Unexpected shuffle mask size"); // Construct a shuffle mask from constant integers or UNDEFs. - Constant *Indexes[8] = {NULL}; + Constant *Indexes[64] = {nullptr}; for (unsigned I = 0; I < Size; ++I) { Constant *COp = V->getAggregateElement(I); @@ -846,8 +918,8 @@ static Value *simplifyX86vpermv(const IntrinsicInst &II, continue; } - APInt Index = cast<ConstantInt>(COp)->getValue(); - Index = Index.zextOrTrunc(32).getLoBits(3); + uint32_t Index = cast<ConstantInt>(COp)->getZExtValue(); + Index &= Size - 1; Indexes[I] = ConstantInt::get(MaskEltTy, Index); } @@ -962,6 +1034,36 @@ static Value *simplifyX86vpcom(const IntrinsicInst &II, return nullptr; } +// Emit a select instruction and appropriate bitcasts to help simplify +// masked intrinsics. +static Value *emitX86MaskSelect(Value *Mask, Value *Op0, Value *Op1, + InstCombiner::BuilderTy &Builder) { + unsigned VWidth = Op0->getType()->getVectorNumElements(); + + // If the mask is all ones we don't need the select. But we need to check + // only the bit thats will be used in case VWidth is less than 8. + if (auto *C = dyn_cast<ConstantInt>(Mask)) + if (C->getValue().zextOrTrunc(VWidth).isAllOnesValue()) + return Op0; + + auto *MaskTy = VectorType::get(Builder.getInt1Ty(), + cast<IntegerType>(Mask->getType())->getBitWidth()); + Mask = Builder.CreateBitCast(Mask, MaskTy); + + // If we have less than 8 elements, then the starting mask was an i8 and + // we need to extract down to the right number of elements. + if (VWidth < 8) { + uint32_t Indices[4]; + for (unsigned i = 0; i != VWidth; ++i) + Indices[i] = i; + Mask = Builder.CreateShuffleVector(Mask, Mask, + makeArrayRef(Indices, VWidth), + "extract"); + } + + return Builder.CreateSelect(Mask, Op0, Op1); +} + static Value *simplifyMinnumMaxnum(const IntrinsicInst &II) { Value *Arg0 = II.getArgOperand(0); Value *Arg1 = II.getArgOperand(1); @@ -1104,6 +1206,50 @@ static Instruction *simplifyMaskedScatter(IntrinsicInst &II, InstCombiner &IC) { return nullptr; } +static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombiner &IC) { + assert((II.getIntrinsicID() == Intrinsic::cttz || + II.getIntrinsicID() == Intrinsic::ctlz) && + "Expected cttz or ctlz intrinsic"); + Value *Op0 = II.getArgOperand(0); + // FIXME: Try to simplify vectors of integers. + auto *IT = dyn_cast<IntegerType>(Op0->getType()); + if (!IT) + return nullptr; + + unsigned BitWidth = IT->getBitWidth(); + APInt KnownZero(BitWidth, 0); + APInt KnownOne(BitWidth, 0); + IC.computeKnownBits(Op0, KnownZero, KnownOne, 0, &II); + + // Create a mask for bits above (ctlz) or below (cttz) the first known one. + bool IsTZ = II.getIntrinsicID() == Intrinsic::cttz; + unsigned NumMaskBits = IsTZ ? KnownOne.countTrailingZeros() + : KnownOne.countLeadingZeros(); + APInt Mask = IsTZ ? APInt::getLowBitsSet(BitWidth, NumMaskBits) + : APInt::getHighBitsSet(BitWidth, NumMaskBits); + + // If all bits above (ctlz) or below (cttz) the first known one are known + // zero, this value is constant. + // FIXME: This should be in InstSimplify because we're replacing an + // instruction with a constant. + if ((Mask & KnownZero) == Mask) { + auto *C = ConstantInt::get(IT, APInt(BitWidth, NumMaskBits)); + return IC.replaceInstUsesWith(II, C); + } + + // If the input to cttz/ctlz is known to be non-zero, + // then change the 'ZeroIsUndef' parameter to 'true' + // because we know the zero behavior can't affect the result. + if (KnownOne != 0 || isKnownNonZero(Op0, IC.getDataLayout())) { + if (!match(II.getArgOperand(1), m_One())) { + II.setOperand(1, IC.Builder->getTrue()); + return &II; + } + } + + return nullptr; +} + // TODO: If the x86 backend knew how to convert a bool vector mask back to an // XMM register mask efficiently, we could transform all x86 masked intrinsics // to LLVM masked intrinsics and remove the x86 masked intrinsic defs. @@ -1243,16 +1389,15 @@ Instruction *InstCombiner::visitVACopyInst(VACopyInst &I) { Instruction *InstCombiner::visitCallInst(CallInst &CI) { auto Args = CI.arg_operands(); if (Value *V = SimplifyCall(CI.getCalledValue(), Args.begin(), Args.end(), DL, - TLI, DT, AC)) + &TLI, &DT, &AC)) return replaceInstUsesWith(CI, V); - if (isFreeCall(&CI, TLI)) + if (isFreeCall(&CI, &TLI)) return visitFree(CI); // If the caller function is nounwind, mark the call as nounwind, even if the // callee isn't. - if (CI.getParent()->getParent()->doesNotThrow() && - !CI.doesNotThrow()) { + if (CI.getFunction()->doesNotThrow() && !CI.doesNotThrow()) { CI.setDoesNotThrow(); return &CI; } @@ -1323,26 +1468,15 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { APInt DemandedElts = APInt::getLowBitsSet(Width, DemandedWidth); return SimplifyDemandedVectorElts(Op, DemandedElts, UndefElts); }; - auto SimplifyDemandedVectorEltsHigh = [this](Value *Op, unsigned Width, - unsigned DemandedWidth) { - APInt UndefElts(Width, 0); - APInt DemandedElts = APInt::getHighBitsSet(Width, DemandedWidth); - return SimplifyDemandedVectorElts(Op, DemandedElts, UndefElts); - }; switch (II->getIntrinsicID()) { default: break; - case Intrinsic::objectsize: { - uint64_t Size; - if (getObjectSize(II->getArgOperand(0), Size, DL, TLI)) { - APInt APSize(II->getType()->getIntegerBitWidth(), Size); - // Equality check to be sure that `Size` can fit in a value of type - // `II->getType()` - if (APSize == Size) - return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), APSize)); - } + case Intrinsic::objectsize: + if (ConstantInt *N = + lowerObjectSizeCall(II, DL, &TLI, /*MustSucceed=*/false)) + return replaceInstUsesWith(CI, N); return nullptr; - } + case Intrinsic::bswap: { Value *IIOperand = II->getArgOperand(0); Value *X = nullptr; @@ -1397,41 +1531,11 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { II->getArgOperand(0)); } break; - case Intrinsic::cttz: { - // If all bits below the first known one are known zero, - // this value is constant. - IntegerType *IT = dyn_cast<IntegerType>(II->getArgOperand(0)->getType()); - // FIXME: Try to simplify vectors of integers. - if (!IT) break; - uint32_t BitWidth = IT->getBitWidth(); - APInt KnownZero(BitWidth, 0); - APInt KnownOne(BitWidth, 0); - computeKnownBits(II->getArgOperand(0), KnownZero, KnownOne, 0, II); - unsigned TrailingZeros = KnownOne.countTrailingZeros(); - APInt Mask(APInt::getLowBitsSet(BitWidth, TrailingZeros)); - if ((Mask & KnownZero) == Mask) - return replaceInstUsesWith(CI, ConstantInt::get(IT, - APInt(BitWidth, TrailingZeros))); - - } - break; - case Intrinsic::ctlz: { - // If all bits above the first known one are known zero, - // this value is constant. - IntegerType *IT = dyn_cast<IntegerType>(II->getArgOperand(0)->getType()); - // FIXME: Try to simplify vectors of integers. - if (!IT) break; - uint32_t BitWidth = IT->getBitWidth(); - APInt KnownZero(BitWidth, 0); - APInt KnownOne(BitWidth, 0); - computeKnownBits(II->getArgOperand(0), KnownZero, KnownOne, 0, II); - unsigned LeadingZeros = KnownOne.countLeadingZeros(); - APInt Mask(APInt::getHighBitsSet(BitWidth, LeadingZeros)); - if ((Mask & KnownZero) == Mask) - return replaceInstUsesWith(CI, ConstantInt::get(IT, - APInt(BitWidth, LeadingZeros))); - } + case Intrinsic::cttz: + case Intrinsic::ctlz: + if (auto *I = foldCttzCtlz(*II, *this)) + return I; break; case Intrinsic::uadd_with_overflow: @@ -1446,7 +1550,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { II->setArgOperand(1, LHS); return II; } - // fall through + LLVM_FALLTHROUGH; case Intrinsic::usub_with_overflow: case Intrinsic::ssub_with_overflow: { @@ -1480,8 +1584,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::ppc_altivec_lvx: case Intrinsic::ppc_altivec_lvxl: // Turn PPC lvx -> load if the pointer is known aligned. - if (getOrEnforceKnownAlignment(II->getArgOperand(0), 16, DL, II, AC, DT) >= - 16) { + if (getOrEnforceKnownAlignment(II->getArgOperand(0), 16, DL, II, &AC, + &DT) >= 16) { Value *Ptr = Builder->CreateBitCast(II->getArgOperand(0), PointerType::getUnqual(II->getType())); return new LoadInst(Ptr); @@ -1497,8 +1601,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::ppc_altivec_stvx: case Intrinsic::ppc_altivec_stvxl: // Turn stvx -> store if the pointer is known aligned. - if (getOrEnforceKnownAlignment(II->getArgOperand(1), 16, DL, II, AC, DT) >= - 16) { + if (getOrEnforceKnownAlignment(II->getArgOperand(1), 16, DL, II, &AC, + &DT) >= 16) { Type *OpPtrTy = PointerType::getUnqual(II->getArgOperand(0)->getType()); Value *Ptr = Builder->CreateBitCast(II->getArgOperand(1), OpPtrTy); @@ -1514,8 +1618,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { } case Intrinsic::ppc_qpx_qvlfs: // Turn PPC QPX qvlfs -> load if the pointer is known aligned. - if (getOrEnforceKnownAlignment(II->getArgOperand(0), 16, DL, II, AC, DT) >= - 16) { + if (getOrEnforceKnownAlignment(II->getArgOperand(0), 16, DL, II, &AC, + &DT) >= 16) { Type *VTy = VectorType::get(Builder->getFloatTy(), II->getType()->getVectorNumElements()); Value *Ptr = Builder->CreateBitCast(II->getArgOperand(0), @@ -1526,8 +1630,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; case Intrinsic::ppc_qpx_qvlfd: // Turn PPC QPX qvlfd -> load if the pointer is known aligned. - if (getOrEnforceKnownAlignment(II->getArgOperand(0), 32, DL, II, AC, DT) >= - 32) { + if (getOrEnforceKnownAlignment(II->getArgOperand(0), 32, DL, II, &AC, + &DT) >= 32) { Value *Ptr = Builder->CreateBitCast(II->getArgOperand(0), PointerType::getUnqual(II->getType())); return new LoadInst(Ptr); @@ -1535,8 +1639,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; case Intrinsic::ppc_qpx_qvstfs: // Turn PPC QPX qvstfs -> store if the pointer is known aligned. - if (getOrEnforceKnownAlignment(II->getArgOperand(1), 16, DL, II, AC, DT) >= - 16) { + if (getOrEnforceKnownAlignment(II->getArgOperand(1), 16, DL, II, &AC, + &DT) >= 16) { Type *VTy = VectorType::get(Builder->getFloatTy(), II->getArgOperand(0)->getType()->getVectorNumElements()); Value *TOp = Builder->CreateFPTrunc(II->getArgOperand(0), VTy); @@ -1547,8 +1651,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; case Intrinsic::ppc_qpx_qvstfd: // Turn PPC QPX qvstfd -> store if the pointer is known aligned. - if (getOrEnforceKnownAlignment(II->getArgOperand(1), 32, DL, II, AC, DT) >= - 32) { + if (getOrEnforceKnownAlignment(II->getArgOperand(1), 32, DL, II, &AC, + &DT) >= 32) { Type *OpPtrTy = PointerType::getUnqual(II->getArgOperand(0)->getType()); Value *Ptr = Builder->CreateBitCast(II->getArgOperand(1), OpPtrTy); @@ -1607,7 +1711,23 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_sse2_cvtsd2si: case Intrinsic::x86_sse2_cvtsd2si64: case Intrinsic::x86_sse2_cvttsd2si: - case Intrinsic::x86_sse2_cvttsd2si64: { + case Intrinsic::x86_sse2_cvttsd2si64: + case Intrinsic::x86_avx512_vcvtss2si32: + case Intrinsic::x86_avx512_vcvtss2si64: + case Intrinsic::x86_avx512_vcvtss2usi32: + case Intrinsic::x86_avx512_vcvtss2usi64: + case Intrinsic::x86_avx512_vcvtsd2si32: + case Intrinsic::x86_avx512_vcvtsd2si64: + case Intrinsic::x86_avx512_vcvtsd2usi32: + case Intrinsic::x86_avx512_vcvtsd2usi64: + case Intrinsic::x86_avx512_cvttss2si: + case Intrinsic::x86_avx512_cvttss2si64: + case Intrinsic::x86_avx512_cvttss2usi: + case Intrinsic::x86_avx512_cvttss2usi64: + case Intrinsic::x86_avx512_cvttsd2si: + case Intrinsic::x86_avx512_cvttsd2si64: + case Intrinsic::x86_avx512_cvttsd2usi: + case Intrinsic::x86_avx512_cvttsd2usi64: { // These intrinsics only demand the 0th element of their input vectors. If // we can simplify the input based on that, do so now. Value *Arg = II->getArgOperand(0); @@ -1654,7 +1774,11 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_sse2_ucomigt_sd: case Intrinsic::x86_sse2_ucomile_sd: case Intrinsic::x86_sse2_ucomilt_sd: - case Intrinsic::x86_sse2_ucomineq_sd: { + case Intrinsic::x86_sse2_ucomineq_sd: + case Intrinsic::x86_avx512_vcomi_ss: + case Intrinsic::x86_avx512_vcomi_sd: + case Intrinsic::x86_avx512_mask_cmp_ss: + case Intrinsic::x86_avx512_mask_cmp_sd: { // These intrinsics only demand the 0th element of their input vectors. If // we can simplify the input based on that, do so now. bool MadeChange = false; @@ -1674,50 +1798,155 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } - case Intrinsic::x86_sse_add_ss: - case Intrinsic::x86_sse_sub_ss: - case Intrinsic::x86_sse_mul_ss: - case Intrinsic::x86_sse_div_ss: + case Intrinsic::x86_avx512_mask_add_ps_512: + case Intrinsic::x86_avx512_mask_div_ps_512: + case Intrinsic::x86_avx512_mask_mul_ps_512: + case Intrinsic::x86_avx512_mask_sub_ps_512: + case Intrinsic::x86_avx512_mask_add_pd_512: + case Intrinsic::x86_avx512_mask_div_pd_512: + case Intrinsic::x86_avx512_mask_mul_pd_512: + case Intrinsic::x86_avx512_mask_sub_pd_512: + // If the rounding mode is CUR_DIRECTION(4) we can turn these into regular + // IR operations. + if (auto *R = dyn_cast<ConstantInt>(II->getArgOperand(4))) { + if (R->getValue() == 4) { + Value *Arg0 = II->getArgOperand(0); + Value *Arg1 = II->getArgOperand(1); + + Value *V; + switch (II->getIntrinsicID()) { + default: llvm_unreachable("Case stmts out of sync!"); + case Intrinsic::x86_avx512_mask_add_ps_512: + case Intrinsic::x86_avx512_mask_add_pd_512: + V = Builder->CreateFAdd(Arg0, Arg1); + break; + case Intrinsic::x86_avx512_mask_sub_ps_512: + case Intrinsic::x86_avx512_mask_sub_pd_512: + V = Builder->CreateFSub(Arg0, Arg1); + break; + case Intrinsic::x86_avx512_mask_mul_ps_512: + case Intrinsic::x86_avx512_mask_mul_pd_512: + V = Builder->CreateFMul(Arg0, Arg1); + break; + case Intrinsic::x86_avx512_mask_div_ps_512: + case Intrinsic::x86_avx512_mask_div_pd_512: + V = Builder->CreateFDiv(Arg0, Arg1); + break; + } + + // Create a select for the masking. + V = emitX86MaskSelect(II->getArgOperand(3), V, II->getArgOperand(2), + *Builder); + return replaceInstUsesWith(*II, V); + } + } + break; + + case Intrinsic::x86_avx512_mask_add_ss_round: + case Intrinsic::x86_avx512_mask_div_ss_round: + case Intrinsic::x86_avx512_mask_mul_ss_round: + case Intrinsic::x86_avx512_mask_sub_ss_round: + case Intrinsic::x86_avx512_mask_add_sd_round: + case Intrinsic::x86_avx512_mask_div_sd_round: + case Intrinsic::x86_avx512_mask_mul_sd_round: + case Intrinsic::x86_avx512_mask_sub_sd_round: + // If the rounding mode is CUR_DIRECTION(4) we can turn these into regular + // IR operations. + if (auto *R = dyn_cast<ConstantInt>(II->getArgOperand(4))) { + if (R->getValue() == 4) { + // Extract the element as scalars. + Value *Arg0 = II->getArgOperand(0); + Value *Arg1 = II->getArgOperand(1); + Value *LHS = Builder->CreateExtractElement(Arg0, (uint64_t)0); + Value *RHS = Builder->CreateExtractElement(Arg1, (uint64_t)0); + + Value *V; + switch (II->getIntrinsicID()) { + default: llvm_unreachable("Case stmts out of sync!"); + case Intrinsic::x86_avx512_mask_add_ss_round: + case Intrinsic::x86_avx512_mask_add_sd_round: + V = Builder->CreateFAdd(LHS, RHS); + break; + case Intrinsic::x86_avx512_mask_sub_ss_round: + case Intrinsic::x86_avx512_mask_sub_sd_round: + V = Builder->CreateFSub(LHS, RHS); + break; + case Intrinsic::x86_avx512_mask_mul_ss_round: + case Intrinsic::x86_avx512_mask_mul_sd_round: + V = Builder->CreateFMul(LHS, RHS); + break; + case Intrinsic::x86_avx512_mask_div_ss_round: + case Intrinsic::x86_avx512_mask_div_sd_round: + V = Builder->CreateFDiv(LHS, RHS); + break; + } + + // Handle the masking aspect of the intrinsic. + Value *Mask = II->getArgOperand(3); + auto *C = dyn_cast<ConstantInt>(Mask); + // We don't need a select if we know the mask bit is a 1. + if (!C || !C->getValue()[0]) { + // Cast the mask to an i1 vector and then extract the lowest element. + auto *MaskTy = VectorType::get(Builder->getInt1Ty(), + cast<IntegerType>(Mask->getType())->getBitWidth()); + Mask = Builder->CreateBitCast(Mask, MaskTy); + Mask = Builder->CreateExtractElement(Mask, (uint64_t)0); + // Extract the lowest element from the passthru operand. + Value *Passthru = Builder->CreateExtractElement(II->getArgOperand(2), + (uint64_t)0); + V = Builder->CreateSelect(Mask, V, Passthru); + } + + // Insert the result back into the original argument 0. + V = Builder->CreateInsertElement(Arg0, V, (uint64_t)0); + + return replaceInstUsesWith(*II, V); + } + } + LLVM_FALLTHROUGH; + + // X86 scalar intrinsics simplified with SimplifyDemandedVectorElts. + case Intrinsic::x86_avx512_mask_max_ss_round: + case Intrinsic::x86_avx512_mask_min_ss_round: + case Intrinsic::x86_avx512_mask_max_sd_round: + case Intrinsic::x86_avx512_mask_min_sd_round: + case Intrinsic::x86_avx512_mask_vfmadd_ss: + case Intrinsic::x86_avx512_mask_vfmadd_sd: + case Intrinsic::x86_avx512_maskz_vfmadd_ss: + case Intrinsic::x86_avx512_maskz_vfmadd_sd: + case Intrinsic::x86_avx512_mask3_vfmadd_ss: + case Intrinsic::x86_avx512_mask3_vfmadd_sd: + case Intrinsic::x86_avx512_mask3_vfmsub_ss: + case Intrinsic::x86_avx512_mask3_vfmsub_sd: + case Intrinsic::x86_avx512_mask3_vfnmsub_ss: + case Intrinsic::x86_avx512_mask3_vfnmsub_sd: + case Intrinsic::x86_fma_vfmadd_ss: + case Intrinsic::x86_fma_vfmsub_ss: + case Intrinsic::x86_fma_vfnmadd_ss: + case Intrinsic::x86_fma_vfnmsub_ss: + case Intrinsic::x86_fma_vfmadd_sd: + case Intrinsic::x86_fma_vfmsub_sd: + case Intrinsic::x86_fma_vfnmadd_sd: + case Intrinsic::x86_fma_vfnmsub_sd: + case Intrinsic::x86_sse_cmp_ss: case Intrinsic::x86_sse_min_ss: case Intrinsic::x86_sse_max_ss: - case Intrinsic::x86_sse_cmp_ss: - case Intrinsic::x86_sse2_add_sd: - case Intrinsic::x86_sse2_sub_sd: - case Intrinsic::x86_sse2_mul_sd: - case Intrinsic::x86_sse2_div_sd: + case Intrinsic::x86_sse2_cmp_sd: case Intrinsic::x86_sse2_min_sd: case Intrinsic::x86_sse2_max_sd: - case Intrinsic::x86_sse2_cmp_sd: { - // These intrinsics only demand the lowest element of the second input - // vector. - Value *Arg1 = II->getArgOperand(1); - unsigned VWidth = Arg1->getType()->getVectorNumElements(); - if (Value *V = SimplifyDemandedVectorEltsLow(Arg1, VWidth, 1)) { - II->setArgOperand(1, V); - return II; - } - break; - } - case Intrinsic::x86_sse41_round_ss: - case Intrinsic::x86_sse41_round_sd: { - // These intrinsics demand the upper elements of the first input vector and - // the lowest element of the second input vector. - bool MadeChange = false; - Value *Arg0 = II->getArgOperand(0); - Value *Arg1 = II->getArgOperand(1); - unsigned VWidth = Arg0->getType()->getVectorNumElements(); - if (Value *V = SimplifyDemandedVectorEltsHigh(Arg0, VWidth, VWidth - 1)) { - II->setArgOperand(0, V); - MadeChange = true; - } - if (Value *V = SimplifyDemandedVectorEltsLow(Arg1, VWidth, 1)) { - II->setArgOperand(1, V); - MadeChange = true; - } - if (MadeChange) - return II; - break; + case Intrinsic::x86_sse41_round_sd: + case Intrinsic::x86_xop_vfrcz_ss: + case Intrinsic::x86_xop_vfrcz_sd: { + unsigned VWidth = II->getType()->getVectorNumElements(); + APInt UndefElts(VWidth, 0); + APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth)); + if (Value *V = SimplifyDemandedVectorElts(II, AllOnesEltMask, UndefElts)) { + if (V != II) + return replaceInstUsesWith(*II, V); + return II; + } + break; } // Constant fold ashr( <A x Bi>, Ci ). @@ -1727,18 +1956,29 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_sse2_psrai_w: case Intrinsic::x86_avx2_psrai_d: case Intrinsic::x86_avx2_psrai_w: + case Intrinsic::x86_avx512_psrai_q_128: + case Intrinsic::x86_avx512_psrai_q_256: + case Intrinsic::x86_avx512_psrai_d_512: + case Intrinsic::x86_avx512_psrai_q_512: + case Intrinsic::x86_avx512_psrai_w_512: case Intrinsic::x86_sse2_psrli_d: case Intrinsic::x86_sse2_psrli_q: case Intrinsic::x86_sse2_psrli_w: case Intrinsic::x86_avx2_psrli_d: case Intrinsic::x86_avx2_psrli_q: case Intrinsic::x86_avx2_psrli_w: + case Intrinsic::x86_avx512_psrli_d_512: + case Intrinsic::x86_avx512_psrli_q_512: + case Intrinsic::x86_avx512_psrli_w_512: case Intrinsic::x86_sse2_pslli_d: case Intrinsic::x86_sse2_pslli_q: case Intrinsic::x86_sse2_pslli_w: case Intrinsic::x86_avx2_pslli_d: case Intrinsic::x86_avx2_pslli_q: case Intrinsic::x86_avx2_pslli_w: + case Intrinsic::x86_avx512_pslli_d_512: + case Intrinsic::x86_avx512_pslli_q_512: + case Intrinsic::x86_avx512_pslli_w_512: if (Value *V = simplifyX86immShift(*II, *Builder)) return replaceInstUsesWith(*II, V); break; @@ -1747,18 +1987,29 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_sse2_psra_w: case Intrinsic::x86_avx2_psra_d: case Intrinsic::x86_avx2_psra_w: + case Intrinsic::x86_avx512_psra_q_128: + case Intrinsic::x86_avx512_psra_q_256: + case Intrinsic::x86_avx512_psra_d_512: + case Intrinsic::x86_avx512_psra_q_512: + case Intrinsic::x86_avx512_psra_w_512: case Intrinsic::x86_sse2_psrl_d: case Intrinsic::x86_sse2_psrl_q: case Intrinsic::x86_sse2_psrl_w: case Intrinsic::x86_avx2_psrl_d: case Intrinsic::x86_avx2_psrl_q: case Intrinsic::x86_avx2_psrl_w: + case Intrinsic::x86_avx512_psrl_d_512: + case Intrinsic::x86_avx512_psrl_q_512: + case Intrinsic::x86_avx512_psrl_w_512: case Intrinsic::x86_sse2_psll_d: case Intrinsic::x86_sse2_psll_q: case Intrinsic::x86_sse2_psll_w: case Intrinsic::x86_avx2_psll_d: case Intrinsic::x86_avx2_psll_q: - case Intrinsic::x86_avx2_psll_w: { + case Intrinsic::x86_avx2_psll_w: + case Intrinsic::x86_avx512_psll_d_512: + case Intrinsic::x86_avx512_psll_q_512: + case Intrinsic::x86_avx512_psll_w_512: { if (Value *V = simplifyX86immShift(*II, *Builder)) return replaceInstUsesWith(*II, V); @@ -1780,16 +2031,50 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_avx2_psllv_d_256: case Intrinsic::x86_avx2_psllv_q: case Intrinsic::x86_avx2_psllv_q_256: + case Intrinsic::x86_avx512_psllv_d_512: + case Intrinsic::x86_avx512_psllv_q_512: + case Intrinsic::x86_avx512_psllv_w_128: + case Intrinsic::x86_avx512_psllv_w_256: + case Intrinsic::x86_avx512_psllv_w_512: case Intrinsic::x86_avx2_psrav_d: case Intrinsic::x86_avx2_psrav_d_256: + case Intrinsic::x86_avx512_psrav_q_128: + case Intrinsic::x86_avx512_psrav_q_256: + case Intrinsic::x86_avx512_psrav_d_512: + case Intrinsic::x86_avx512_psrav_q_512: + case Intrinsic::x86_avx512_psrav_w_128: + case Intrinsic::x86_avx512_psrav_w_256: + case Intrinsic::x86_avx512_psrav_w_512: case Intrinsic::x86_avx2_psrlv_d: case Intrinsic::x86_avx2_psrlv_d_256: case Intrinsic::x86_avx2_psrlv_q: case Intrinsic::x86_avx2_psrlv_q_256: + case Intrinsic::x86_avx512_psrlv_d_512: + case Intrinsic::x86_avx512_psrlv_q_512: + case Intrinsic::x86_avx512_psrlv_w_128: + case Intrinsic::x86_avx512_psrlv_w_256: + case Intrinsic::x86_avx512_psrlv_w_512: if (Value *V = simplifyX86varShift(*II, *Builder)) return replaceInstUsesWith(*II, V); break; + case Intrinsic::x86_sse2_pmulu_dq: + case Intrinsic::x86_sse41_pmuldq: + case Intrinsic::x86_avx2_pmul_dq: + case Intrinsic::x86_avx2_pmulu_dq: + case Intrinsic::x86_avx512_pmul_dq_512: + case Intrinsic::x86_avx512_pmulu_dq_512: { + unsigned VWidth = II->getType()->getVectorNumElements(); + APInt UndefElts(VWidth, 0); + APInt DemandedElts = APInt::getAllOnesValue(VWidth); + if (Value *V = SimplifyDemandedVectorElts(II, DemandedElts, UndefElts)) { + if (V != II) + return replaceInstUsesWith(*II, V); + return II; + } + break; + } + case Intrinsic::x86_sse41_insertps: if (Value *V = simplifyX86insertps(*II, *Builder)) return replaceInstUsesWith(*II, V); @@ -1807,10 +2092,10 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // See if we're dealing with constant values. Constant *C1 = dyn_cast<Constant>(Op1); ConstantInt *CILength = - C1 ? dyn_cast<ConstantInt>(C1->getAggregateElement((unsigned)0)) + C1 ? dyn_cast_or_null<ConstantInt>(C1->getAggregateElement((unsigned)0)) : nullptr; ConstantInt *CIIndex = - C1 ? dyn_cast<ConstantInt>(C1->getAggregateElement((unsigned)1)) + C1 ? dyn_cast_or_null<ConstantInt>(C1->getAggregateElement((unsigned)1)) : nullptr; // Attempt to simplify to a constant, shuffle vector or EXTRQI call. @@ -1870,7 +2155,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // See if we're dealing with constant values. Constant *C1 = dyn_cast<Constant>(Op1); ConstantInt *CI11 = - C1 ? dyn_cast<ConstantInt>(C1->getAggregateElement((unsigned)1)) + C1 ? dyn_cast_or_null<ConstantInt>(C1->getAggregateElement((unsigned)1)) : nullptr; // Attempt to simplify to a constant, shuffle vector or INSERTQI call. @@ -1964,14 +2249,17 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_ssse3_pshuf_b_128: case Intrinsic::x86_avx2_pshuf_b: + case Intrinsic::x86_avx512_pshuf_b_512: if (Value *V = simplifyX86pshufb(*II, *Builder)) return replaceInstUsesWith(*II, V); break; case Intrinsic::x86_avx_vpermilvar_ps: case Intrinsic::x86_avx_vpermilvar_ps_256: + case Intrinsic::x86_avx512_vpermilvar_ps_512: case Intrinsic::x86_avx_vpermilvar_pd: case Intrinsic::x86_avx_vpermilvar_pd_256: + case Intrinsic::x86_avx512_vpermilvar_pd_512: if (Value *V = simplifyX86vpermilvar(*II, *Builder)) return replaceInstUsesWith(*II, V); break; @@ -1982,6 +2270,28 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return replaceInstUsesWith(*II, V); break; + case Intrinsic::x86_avx512_mask_permvar_df_256: + case Intrinsic::x86_avx512_mask_permvar_df_512: + case Intrinsic::x86_avx512_mask_permvar_di_256: + case Intrinsic::x86_avx512_mask_permvar_di_512: + case Intrinsic::x86_avx512_mask_permvar_hi_128: + case Intrinsic::x86_avx512_mask_permvar_hi_256: + case Intrinsic::x86_avx512_mask_permvar_hi_512: + case Intrinsic::x86_avx512_mask_permvar_qi_128: + case Intrinsic::x86_avx512_mask_permvar_qi_256: + case Intrinsic::x86_avx512_mask_permvar_qi_512: + case Intrinsic::x86_avx512_mask_permvar_sf_256: + case Intrinsic::x86_avx512_mask_permvar_sf_512: + case Intrinsic::x86_avx512_mask_permvar_si_256: + case Intrinsic::x86_avx512_mask_permvar_si_512: + if (Value *V = simplifyX86vpermv(*II, *Builder)) { + // We simplified the permuting, now create a select for the masking. + V = emitX86MaskSelect(II->getArgOperand(3), V, II->getArgOperand(2), + *Builder); + return replaceInstUsesWith(*II, V); + } + break; + case Intrinsic::x86_avx_vperm2f128_pd_256: case Intrinsic::x86_avx_vperm2f128_ps_256: case Intrinsic::x86_avx_vperm2f128_si_256: @@ -2104,7 +2414,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::arm_neon_vst2lane: case Intrinsic::arm_neon_vst3lane: case Intrinsic::arm_neon_vst4lane: { - unsigned MemAlign = getKnownAlignment(II->getArgOperand(0), DL, II, AC, DT); + unsigned MemAlign = + getKnownAlignment(II->getArgOperand(0), DL, II, &AC, &DT); unsigned AlignArg = II->getNumArgOperands() - 1; ConstantInt *IntrAlign = dyn_cast<ConstantInt>(II->getArgOperand(AlignArg)); if (IntrAlign && IntrAlign->getZExtValue() < MemAlign) { @@ -2194,6 +2505,85 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } + case Intrinsic::amdgcn_class: { + enum { + S_NAN = 1 << 0, // Signaling NaN + Q_NAN = 1 << 1, // Quiet NaN + N_INFINITY = 1 << 2, // Negative infinity + N_NORMAL = 1 << 3, // Negative normal + N_SUBNORMAL = 1 << 4, // Negative subnormal + N_ZERO = 1 << 5, // Negative zero + P_ZERO = 1 << 6, // Positive zero + P_SUBNORMAL = 1 << 7, // Positive subnormal + P_NORMAL = 1 << 8, // Positive normal + P_INFINITY = 1 << 9 // Positive infinity + }; + + const uint32_t FullMask = S_NAN | Q_NAN | N_INFINITY | N_NORMAL | + N_SUBNORMAL | N_ZERO | P_ZERO | P_SUBNORMAL | P_NORMAL | P_INFINITY; + + Value *Src0 = II->getArgOperand(0); + Value *Src1 = II->getArgOperand(1); + const ConstantInt *CMask = dyn_cast<ConstantInt>(Src1); + if (!CMask) { + if (isa<UndefValue>(Src0)) + return replaceInstUsesWith(*II, UndefValue::get(II->getType())); + + if (isa<UndefValue>(Src1)) + return replaceInstUsesWith(*II, ConstantInt::get(II->getType(), false)); + break; + } + + uint32_t Mask = CMask->getZExtValue(); + + // If all tests are made, it doesn't matter what the value is. + if ((Mask & FullMask) == FullMask) + return replaceInstUsesWith(*II, ConstantInt::get(II->getType(), true)); + + if ((Mask & FullMask) == 0) + return replaceInstUsesWith(*II, ConstantInt::get(II->getType(), false)); + + if (Mask == (S_NAN | Q_NAN)) { + // Equivalent of isnan. Replace with standard fcmp. + Value *FCmp = Builder->CreateFCmpUNO(Src0, Src0); + FCmp->takeName(II); + return replaceInstUsesWith(*II, FCmp); + } + + const ConstantFP *CVal = dyn_cast<ConstantFP>(Src0); + if (!CVal) { + if (isa<UndefValue>(Src0)) + return replaceInstUsesWith(*II, UndefValue::get(II->getType())); + + // Clamp mask to used bits + if ((Mask & FullMask) != Mask) { + CallInst *NewCall = Builder->CreateCall(II->getCalledFunction(), + { Src0, ConstantInt::get(Src1->getType(), Mask & FullMask) } + ); + + NewCall->takeName(II); + return replaceInstUsesWith(*II, NewCall); + } + + break; + } + + const APFloat &Val = CVal->getValueAPF(); + + bool Result = + ((Mask & S_NAN) && Val.isNaN() && Val.isSignaling()) || + ((Mask & Q_NAN) && Val.isNaN() && !Val.isSignaling()) || + ((Mask & N_INFINITY) && Val.isInfinity() && Val.isNegative()) || + ((Mask & N_NORMAL) && Val.isNormal() && Val.isNegative()) || + ((Mask & N_SUBNORMAL) && Val.isDenormal() && Val.isNegative()) || + ((Mask & N_ZERO) && Val.isZero() && Val.isNegative()) || + ((Mask & P_ZERO) && Val.isZero() && !Val.isNegative()) || + ((Mask & P_SUBNORMAL) && Val.isDenormal() && !Val.isNegative()) || + ((Mask & P_NORMAL) && Val.isNormal() && !Val.isNegative()) || + ((Mask & P_INFINITY) && Val.isInfinity() && !Val.isNegative()); + + return replaceInstUsesWith(*II, ConstantInt::get(II->getType(), Result)); + } case Intrinsic::stackrestore: { // If the save is right next to the restore, remove the restore. This can // happen when variable allocas are DCE'd. @@ -2243,6 +2633,11 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } case Intrinsic::lifetime_start: + // Asan needs to poison memory to detect invalid access which is possible + // even for empty lifetime range. + if (II->getFunction()->hasFnAttribute(Attribute::SanitizeAddress)) + break; + if (removeTriviallyEmptyRange(*II, Intrinsic::lifetime_start, Intrinsic::lifetime_end, *this)) return nullptr; @@ -2283,7 +2678,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { RHS->getType()->isPointerTy() && cast<Constant>(RHS)->isNullValue()) { LoadInst* LI = cast<LoadInst>(LHS); - if (isValidAssumeForContext(II, LI, DT)) { + if (isValidAssumeForContext(II, LI, &DT)) { MDNode *MD = MDNode::get(II->getContext(), None); LI->setMetadata(LLVMContext::MD_nonnull, MD); return eraseInstFromFunction(*II); @@ -2329,7 +2724,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return replaceInstUsesWith(*II, ConstantPointerNull::get(PT)); // isKnownNonNull -> nonnull attribute - if (isKnownNonNullAt(DerivedPtr, II, DT)) + if (isKnownNonNullAt(DerivedPtr, II, &DT)) II->addAttribute(AttributeSet::ReturnIndex, Attribute::NonNull); } @@ -2389,7 +2784,7 @@ Instruction *InstCombiner::tryOptimizeCall(CallInst *CI) { auto InstCombineRAUW = [this](Instruction *From, Value *With) { replaceInstUsesWith(*From, With); }; - LibCallSimplifier Simplifier(DL, TLI, InstCombineRAUW); + LibCallSimplifier Simplifier(DL, &TLI, InstCombineRAUW); if (Value *With = Simplifier.optimizeCall(CI)) { ++NumSimplified; return CI->use_empty() ? CI : replaceInstUsesWith(*CI, With); @@ -2477,8 +2872,7 @@ static IntrinsicInst *findInitTrampoline(Value *Callee) { /// Improvements for call and invoke instructions. Instruction *InstCombiner::visitCallSite(CallSite CS) { - - if (isAllocLikeFn(CS.getInstruction(), TLI)) + if (isAllocLikeFn(CS.getInstruction(), &TLI)) return visitAllocSite(*CS.getInstruction()); bool Changed = false; @@ -2492,7 +2886,7 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) { for (Value *V : CS.args()) { if (V->getType()->isPointerTy() && !CS.paramHasAttr(ArgNo + 1, Attribute::NonNull) && - isKnownNonNullAt(V, CS.getInstruction(), DT)) + isKnownNonNullAt(V, CS.getInstruction(), &DT)) Indices.push_back(ArgNo + 1); ArgNo++; } @@ -2613,14 +3007,14 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) { /// If the callee is a constexpr cast of a function, attempt to move the cast to /// the arguments of the call/invoke. bool InstCombiner::transformConstExprCastCall(CallSite CS) { - Function *Callee = - dyn_cast<Function>(CS.getCalledValue()->stripPointerCasts()); + auto *Callee = dyn_cast<Function>(CS.getCalledValue()->stripPointerCasts()); if (!Callee) return false; - // The prototype of thunks are a lie, don't try to directly call such - // functions. + + // The prototype of a thunk is a lie. Don't directly call such a function. if (Callee->hasFnAttribute("thunk")) return false; + Instruction *Caller = CS.getInstruction(); const AttributeSet &CallerPAL = CS.getAttributes(); @@ -2842,8 +3236,7 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { CallInst *CI = cast<CallInst>(Caller); NC = Builder->CreateCall(Callee, Args, OpBundles); NC->takeName(CI); - if (CI->isTailCall()) - cast<CallInst>(NC)->setTailCall(); + cast<CallInst>(NC)->setTailCallKind(CI->getTailCallKind()); cast<CallInst>(NC)->setCallingConv(CI->getCallingConv()); cast<CallInst>(NC)->setAttributes(NewCallerPAL); } @@ -2966,7 +3359,7 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS, ++Idx; ++I; - } while (1); + } while (true); } // Add any function attributes. @@ -3001,7 +3394,7 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS, ++Idx; ++I; - } while (1); + } while (true); } // Replace the trampoline call with a direct call. Let the generic @@ -3027,10 +3420,10 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS, cast<InvokeInst>(NewCaller)->setAttributes(NewPAL); } else { NewCaller = CallInst::Create(NewCallee, NewArgs, OpBundles); - if (cast<CallInst>(Caller)->isTailCall()) - cast<CallInst>(NewCaller)->setTailCall(); - cast<CallInst>(NewCaller)-> - setCallingConv(cast<CallInst>(Caller)->getCallingConv()); + cast<CallInst>(NewCaller)->setTailCallKind( + cast<CallInst>(Caller)->getTailCallKind()); + cast<CallInst>(NewCaller)->setCallingConv( + cast<CallInst>(Caller)->getCallingConv()); cast<CallInst>(NewCaller)->setAttributes(NewPAL); } diff --git a/lib/Transforms/InstCombine/InstCombineCasts.cpp b/lib/Transforms/InstCombine/InstCombineCasts.cpp index 20556157188f..e74b590e2b7c 100644 --- a/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "InstCombineInternal.h" +#include "llvm/ADT/SetVector.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/PatternMatch.h" @@ -161,8 +162,8 @@ Value *InstCombiner::EvaluateInDifferentType(Value *V, Type *Ty, if (Constant *C = dyn_cast<Constant>(V)) { C = ConstantExpr::getIntegerCast(C, Ty, isSigned /*Sext or ZExt*/); // If we got a constantexpr back, try to simplify it with DL info. - if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) - C = ConstantFoldConstantExpression(CE, DL, TLI); + if (Constant *FoldedC = ConstantFoldConstant(C, DL, &TLI)) + C = FoldedC; return C; } @@ -227,20 +228,14 @@ Value *InstCombiner::EvaluateInDifferentType(Value *V, Type *Ty, return InsertNewInstWith(Res, *I); } +Instruction::CastOps InstCombiner::isEliminableCastPair(const CastInst *CI1, + const CastInst *CI2) { + Type *SrcTy = CI1->getSrcTy(); + Type *MidTy = CI1->getDestTy(); + Type *DstTy = CI2->getDestTy(); -/// This function is a wrapper around CastInst::isEliminableCastPair. It -/// simply extracts arguments and returns what that function returns. -static Instruction::CastOps -isEliminableCastPair(const CastInst *CI, ///< First cast instruction - unsigned opcode, ///< Opcode for the second cast - Type *DstTy, ///< Target type for the second cast - const DataLayout &DL) { - Type *SrcTy = CI->getOperand(0)->getType(); // A from above - Type *MidTy = CI->getType(); // B from above - - // Get the opcodes of the two Cast instructions - Instruction::CastOps firstOp = Instruction::CastOps(CI->getOpcode()); - Instruction::CastOps secondOp = Instruction::CastOps(opcode); + Instruction::CastOps firstOp = Instruction::CastOps(CI1->getOpcode()); + Instruction::CastOps secondOp = Instruction::CastOps(CI2->getOpcode()); Type *SrcIntPtrTy = SrcTy->isPtrOrPtrVectorTy() ? DL.getIntPtrType(SrcTy) : nullptr; Type *MidIntPtrTy = @@ -260,54 +255,28 @@ isEliminableCastPair(const CastInst *CI, ///< First cast instruction return Instruction::CastOps(Res); } -/// Return true if the cast from "V to Ty" actually results in any code being -/// generated and is interesting to optimize out. -/// If the cast can be eliminated by some other simple transformation, we prefer -/// to do the simplification first. -bool InstCombiner::ShouldOptimizeCast(Instruction::CastOps opc, const Value *V, - Type *Ty) { - // Noop casts and casts of constants should be eliminated trivially. - if (V->getType() == Ty || isa<Constant>(V)) return false; - - // If this is another cast that can be eliminated, we prefer to have it - // eliminated. - if (const CastInst *CI = dyn_cast<CastInst>(V)) - if (isEliminableCastPair(CI, opc, Ty, DL)) - return false; - - // If this is a vector sext from a compare, then we don't want to break the - // idiom where each element of the extended vector is either zero or all ones. - if (opc == Instruction::SExt && isa<CmpInst>(V) && Ty->isVectorTy()) - return false; - - return true; -} - - /// @brief Implement the transforms common to all CastInst visitors. Instruction *InstCombiner::commonCastTransforms(CastInst &CI) { Value *Src = CI.getOperand(0); - // Many cases of "cast of a cast" are eliminable. If it's eliminable we just - // eliminate it now. - if (CastInst *CSrc = dyn_cast<CastInst>(Src)) { // A->B->C cast - if (Instruction::CastOps opc = - isEliminableCastPair(CSrc, CI.getOpcode(), CI.getType(), DL)) { + // Try to eliminate a cast of a cast. + if (auto *CSrc = dyn_cast<CastInst>(Src)) { // A->B->C cast + if (Instruction::CastOps NewOpc = isEliminableCastPair(CSrc, &CI)) { // The first cast (CSrc) is eliminable so we need to fix up or replace // the second cast (CI). CSrc will then have a good chance of being dead. - return CastInst::Create(opc, CSrc->getOperand(0), CI.getType()); + return CastInst::Create(NewOpc, CSrc->getOperand(0), CI.getType()); } } - // If we are casting a select then fold the cast into the select - if (SelectInst *SI = dyn_cast<SelectInst>(Src)) + // If we are casting a select, then fold the cast into the select. + if (auto *SI = dyn_cast<SelectInst>(Src)) if (Instruction *NV = FoldOpIntoSelect(CI, SI)) return NV; - // If we are casting a PHI then fold the cast into the PHI + // If we are casting a PHI, then fold the cast into the PHI. if (isa<PHINode>(Src)) { - // We don't do this if this would create a PHI node with an illegal type if - // it is currently legal. + // Don't do this if it would create a PHI node with an illegal type from a + // legal type. if (!Src->getType()->isIntegerTy() || !CI.getType()->isIntegerTy() || ShouldChangeType(CI.getType(), Src->getType())) if (Instruction *NV = FoldOpIntoPhi(CI)) @@ -474,19 +443,39 @@ static Instruction *foldVecTruncToExtElt(TruncInst &Trunc, InstCombiner &IC, return ExtractElementInst::Create(VecInput, IC.Builder->getInt32(Elt)); } +/// Try to narrow the width of bitwise logic instructions with constants. +Instruction *InstCombiner::shrinkBitwiseLogic(TruncInst &Trunc) { + Type *SrcTy = Trunc.getSrcTy(); + Type *DestTy = Trunc.getType(); + if (isa<IntegerType>(SrcTy) && !ShouldChangeType(SrcTy, DestTy)) + return nullptr; + + BinaryOperator *LogicOp; + Constant *C; + if (!match(Trunc.getOperand(0), m_OneUse(m_BinOp(LogicOp))) || + !LogicOp->isBitwiseLogicOp() || + !match(LogicOp->getOperand(1), m_Constant(C))) + return nullptr; + + // trunc (logic X, C) --> logic (trunc X, C') + Constant *NarrowC = ConstantExpr::getTrunc(C, DestTy); + Value *NarrowOp0 = Builder->CreateTrunc(LogicOp->getOperand(0), DestTy); + return BinaryOperator::Create(LogicOp->getOpcode(), NarrowOp0, NarrowC); +} + Instruction *InstCombiner::visitTrunc(TruncInst &CI) { if (Instruction *Result = commonCastTransforms(CI)) return Result; // Test if the trunc is the user of a select which is part of a // minimum or maximum operation. If so, don't do any more simplification. - // Even simplifying demanded bits can break the canonical form of a + // Even simplifying demanded bits can break the canonical form of a // min/max. Value *LHS, *RHS; if (SelectInst *SI = dyn_cast<SelectInst>(CI.getOperand(0))) if (matchSelectPattern(SI, LHS, RHS).Flavor != SPF_UNKNOWN) return nullptr; - + // See if we can simplify any instructions used by the input whose sole // purpose is to compute bits we don't care about. if (SimplifyDemandedInstructionBits(CI)) @@ -562,14 +551,26 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { } } - // Transform "trunc (and X, cst)" -> "and (trunc X), cst" so long as the dest - // type isn't non-native. + if (Instruction *I = shrinkBitwiseLogic(CI)) + return I; + if (Src->hasOneUse() && isa<IntegerType>(SrcTy) && - ShouldChangeType(SrcTy, DestTy) && - match(Src, m_And(m_Value(A), m_ConstantInt(Cst)))) { - Value *NewTrunc = Builder->CreateTrunc(A, DestTy, A->getName() + ".tr"); - return BinaryOperator::CreateAnd(NewTrunc, - ConstantExpr::getTrunc(Cst, DestTy)); + ShouldChangeType(SrcTy, DestTy)) { + // Transform "trunc (shl X, cst)" -> "shl (trunc X), cst" so long as the + // dest type is native and cst < dest size. + if (match(Src, m_Shl(m_Value(A), m_ConstantInt(Cst))) && + !match(A, m_Shr(m_Value(), m_Constant()))) { + // Skip shifts of shift by constants. It undoes a combine in + // FoldShiftByConstant and is the extend in reg pattern. + const unsigned DestSize = DestTy->getScalarSizeInBits(); + if (Cst->getValue().ult(DestSize)) { + Value *NewTrunc = Builder->CreateTrunc(A, DestTy, A->getName() + ".tr"); + + return BinaryOperator::Create( + Instruction::Shl, NewTrunc, + ConstantInt::get(DestTy, Cst->getValue().trunc(DestSize))); + } + } } if (Instruction *I = foldVecTruncToExtElt(CI, *this, DL)) @@ -578,10 +579,8 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { return nullptr; } -/// Transform (zext icmp) to bitwise / integer operations in order to eliminate -/// the icmp. -Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, Instruction &CI, - bool DoXform) { +Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, ZExtInst &CI, + bool DoTransform) { // If we are just checking for a icmp eq of a single bit and zext'ing it // to an integer, then shift the bit to the appropriate place and then // cast to integer to avoid the comparison. @@ -592,7 +591,7 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, Instruction &CI, // zext (x >s -1) to i32 --> (x>>u31)^1 true if signbit clear. if ((ICI->getPredicate() == ICmpInst::ICMP_SLT && Op1CV == 0) || (ICI->getPredicate() == ICmpInst::ICMP_SGT && Op1CV.isAllOnesValue())) { - if (!DoXform) return ICI; + if (!DoTransform) return ICI; Value *In = ICI->getOperand(0); Value *Sh = ConstantInt::get(In->getType(), @@ -627,7 +626,7 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, Instruction &CI, APInt KnownZeroMask(~KnownZero); if (KnownZeroMask.isPowerOf2()) { // Exactly 1 possible 1? - if (!DoXform) return ICI; + if (!DoTransform) return ICI; bool isNE = ICI->getPredicate() == ICmpInst::ICMP_NE; if (Op1CV != 0 && (Op1CV != KnownZeroMask)) { @@ -655,7 +654,9 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, Instruction &CI, if (CI.getType() == In->getType()) return replaceInstUsesWith(CI, In); - return CastInst::CreateIntegerCast(In, CI.getType(), false/*ZExt*/); + + Value *IntCast = Builder->CreateIntCast(In, CI.getType(), false); + return replaceInstUsesWith(CI, IntCast); } } } @@ -678,7 +679,7 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, Instruction &CI, APInt KnownBits = KnownZeroLHS | KnownOneLHS; APInt UnknownBit = ~KnownBits; if (UnknownBit.countPopulation() == 1) { - if (!DoXform) return ICI; + if (!DoTransform) return ICI; Value *Result = Builder->CreateXor(LHS, RHS); @@ -760,9 +761,7 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear, // If the operation is an AND/OR/XOR and the bits to clear are zero in the // other side, BitsToClear is ok. - if (Tmp == 0 && - (Opc == Instruction::And || Opc == Instruction::Or || - Opc == Instruction::Xor)) { + if (Tmp == 0 && I->isBitwiseLogicOp()) { // We use MaskedValueIsZero here for generality, but the case we care // about the most is constant RHS. unsigned VSize = V->getType()->getScalarSizeInBits(); @@ -922,16 +921,26 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) { BinaryOperator *SrcI = dyn_cast<BinaryOperator>(Src); if (SrcI && SrcI->getOpcode() == Instruction::Or) { - // zext (or icmp, icmp) --> or (zext icmp), (zext icmp) if at least one - // of the (zext icmp) will be transformed. + // zext (or icmp, icmp) -> or (zext icmp), (zext icmp) if at least one + // of the (zext icmp) can be eliminated. If so, immediately perform the + // according elimination. ICmpInst *LHS = dyn_cast<ICmpInst>(SrcI->getOperand(0)); ICmpInst *RHS = dyn_cast<ICmpInst>(SrcI->getOperand(1)); if (LHS && RHS && LHS->hasOneUse() && RHS->hasOneUse() && (transformZExtICmp(LHS, CI, false) || transformZExtICmp(RHS, CI, false))) { + // zext (or icmp, icmp) -> or (zext icmp), (zext icmp) Value *LCast = Builder->CreateZExt(LHS, CI.getType(), LHS->getName()); Value *RCast = Builder->CreateZExt(RHS, CI.getType(), RHS->getName()); - return BinaryOperator::Create(Instruction::Or, LCast, RCast); + BinaryOperator *Or = BinaryOperator::Create(Instruction::Or, LCast, RCast); + + // Perform the elimination. + if (auto *LZExt = dyn_cast<ZExtInst>(LCast)) + transformZExtICmp(LHS, *LZExt); + if (auto *RZExt = dyn_cast<ZExtInst>(RCast)) + transformZExtICmp(RHS, *RZExt); + + return Or; } } @@ -952,14 +961,6 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) { return BinaryOperator::CreateXor(Builder->CreateAnd(X, ZC), ZC); } - // zext (xor i1 X, true) to i32 --> xor (zext i1 X to i32), 1 - if (SrcI && SrcI->hasOneUse() && - SrcI->getType()->getScalarType()->isIntegerTy(1) && - match(SrcI, m_Not(m_Value(X))) && (!X->hasOneUse() || !isa<CmpInst>(X))) { - Value *New = Builder->CreateZExt(X, CI.getType()); - return BinaryOperator::CreateXor(New, ConstantInt::get(CI.getType(), 1)); - } - return nullptr; } @@ -1132,7 +1133,7 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) { Type *SrcTy = Src->getType(), *DestTy = CI.getType(); // If we know that the value being extended is positive, we can use a zext - // instead. + // instead. bool KnownZero, KnownOne; ComputeSignBit(Src, KnownZero, KnownOne, 0, &CI); if (KnownZero) { @@ -1238,14 +1239,14 @@ static Value *lookThroughFPExtensions(Value *V) { if (CFP->getType() == Type::getPPC_FP128Ty(V->getContext())) return V; // No constant folding of this. // See if the value can be truncated to half and then reextended. - if (Value *V = fitsInFPType(CFP, APFloat::IEEEhalf)) + if (Value *V = fitsInFPType(CFP, APFloat::IEEEhalf())) return V; // See if the value can be truncated to float and then reextended. - if (Value *V = fitsInFPType(CFP, APFloat::IEEEsingle)) + if (Value *V = fitsInFPType(CFP, APFloat::IEEEsingle())) return V; if (CFP->getType()->isDoubleTy()) return V; // Won't shrink. - if (Value *V = fitsInFPType(CFP, APFloat::IEEEdouble)) + if (Value *V = fitsInFPType(CFP, APFloat::IEEEdouble())) return V; // Don't try to shrink to various long double types. } @@ -1789,6 +1790,205 @@ static Instruction *canonicalizeBitCastExtElt(BitCastInst &BitCast, return ExtractElementInst::Create(NewBC, ExtElt->getIndexOperand()); } +/// Change the type of a bitwise logic operation if we can eliminate a bitcast. +static Instruction *foldBitCastBitwiseLogic(BitCastInst &BitCast, + InstCombiner::BuilderTy &Builder) { + Type *DestTy = BitCast.getType(); + BinaryOperator *BO; + if (!DestTy->getScalarType()->isIntegerTy() || + !match(BitCast.getOperand(0), m_OneUse(m_BinOp(BO))) || + !BO->isBitwiseLogicOp()) + return nullptr; + + // FIXME: This transform is restricted to vector types to avoid backend + // problems caused by creating potentially illegal operations. If a fix-up is + // added to handle that situation, we can remove this check. + if (!DestTy->isVectorTy() || !BO->getType()->isVectorTy()) + return nullptr; + + Value *X; + if (match(BO->getOperand(0), m_OneUse(m_BitCast(m_Value(X)))) && + X->getType() == DestTy && !isa<Constant>(X)) { + // bitcast(logic(bitcast(X), Y)) --> logic'(X, bitcast(Y)) + Value *CastedOp1 = Builder.CreateBitCast(BO->getOperand(1), DestTy); + return BinaryOperator::Create(BO->getOpcode(), X, CastedOp1); + } + + if (match(BO->getOperand(1), m_OneUse(m_BitCast(m_Value(X)))) && + X->getType() == DestTy && !isa<Constant>(X)) { + // bitcast(logic(Y, bitcast(X))) --> logic'(bitcast(Y), X) + Value *CastedOp0 = Builder.CreateBitCast(BO->getOperand(0), DestTy); + return BinaryOperator::Create(BO->getOpcode(), CastedOp0, X); + } + + return nullptr; +} + +/// Change the type of a select if we can eliminate a bitcast. +static Instruction *foldBitCastSelect(BitCastInst &BitCast, + InstCombiner::BuilderTy &Builder) { + Value *Cond, *TVal, *FVal; + if (!match(BitCast.getOperand(0), + m_OneUse(m_Select(m_Value(Cond), m_Value(TVal), m_Value(FVal))))) + return nullptr; + + // A vector select must maintain the same number of elements in its operands. + Type *CondTy = Cond->getType(); + Type *DestTy = BitCast.getType(); + if (CondTy->isVectorTy()) { + if (!DestTy->isVectorTy()) + return nullptr; + if (DestTy->getVectorNumElements() != CondTy->getVectorNumElements()) + return nullptr; + } + + // FIXME: This transform is restricted from changing the select between + // scalars and vectors to avoid backend problems caused by creating + // potentially illegal operations. If a fix-up is added to handle that + // situation, we can remove this check. + if (DestTy->isVectorTy() != TVal->getType()->isVectorTy()) + return nullptr; + + auto *Sel = cast<Instruction>(BitCast.getOperand(0)); + Value *X; + if (match(TVal, m_OneUse(m_BitCast(m_Value(X)))) && X->getType() == DestTy && + !isa<Constant>(X)) { + // bitcast(select(Cond, bitcast(X), Y)) --> select'(Cond, X, bitcast(Y)) + Value *CastedVal = Builder.CreateBitCast(FVal, DestTy); + return SelectInst::Create(Cond, X, CastedVal, "", nullptr, Sel); + } + + if (match(FVal, m_OneUse(m_BitCast(m_Value(X)))) && X->getType() == DestTy && + !isa<Constant>(X)) { + // bitcast(select(Cond, Y, bitcast(X))) --> select'(Cond, bitcast(Y), X) + Value *CastedVal = Builder.CreateBitCast(TVal, DestTy); + return SelectInst::Create(Cond, CastedVal, X, "", nullptr, Sel); + } + + return nullptr; +} + +/// Check if all users of CI are StoreInsts. +static bool hasStoreUsersOnly(CastInst &CI) { + for (User *U : CI.users()) { + if (!isa<StoreInst>(U)) + return false; + } + return true; +} + +/// This function handles following case +/// +/// A -> B cast +/// PHI +/// B -> A cast +/// +/// All the related PHI nodes can be replaced by new PHI nodes with type A. +/// The uses of \p CI can be changed to the new PHI node corresponding to \p PN. +Instruction *InstCombiner::optimizeBitCastFromPhi(CastInst &CI, PHINode *PN) { + // BitCast used by Store can be handled in InstCombineLoadStoreAlloca.cpp. + if (hasStoreUsersOnly(CI)) + return nullptr; + + Value *Src = CI.getOperand(0); + Type *SrcTy = Src->getType(); // Type B + Type *DestTy = CI.getType(); // Type A + + SmallVector<PHINode *, 4> PhiWorklist; + SmallSetVector<PHINode *, 4> OldPhiNodes; + + // Find all of the A->B casts and PHI nodes. + // We need to inpect all related PHI nodes, but PHIs can be cyclic, so + // OldPhiNodes is used to track all known PHI nodes, before adding a new + // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first. + PhiWorklist.push_back(PN); + OldPhiNodes.insert(PN); + while (!PhiWorklist.empty()) { + auto *OldPN = PhiWorklist.pop_back_val(); + for (Value *IncValue : OldPN->incoming_values()) { + if (isa<Constant>(IncValue)) + continue; + + if (auto *LI = dyn_cast<LoadInst>(IncValue)) { + // If there is a sequence of one or more load instructions, each loaded + // value is used as address of later load instruction, bitcast is + // necessary to change the value type, don't optimize it. For + // simplicity we give up if the load address comes from another load. + Value *Addr = LI->getOperand(0); + if (Addr == &CI || isa<LoadInst>(Addr)) + return nullptr; + if (LI->hasOneUse() && LI->isSimple()) + continue; + // If a LoadInst has more than one use, changing the type of loaded + // value may create another bitcast. + return nullptr; + } + + if (auto *PNode = dyn_cast<PHINode>(IncValue)) { + if (OldPhiNodes.insert(PNode)) + PhiWorklist.push_back(PNode); + continue; + } + + auto *BCI = dyn_cast<BitCastInst>(IncValue); + // We can't handle other instructions. + if (!BCI) + return nullptr; + + // Verify it's a A->B cast. + Type *TyA = BCI->getOperand(0)->getType(); + Type *TyB = BCI->getType(); + if (TyA != DestTy || TyB != SrcTy) + return nullptr; + } + } + + // For each old PHI node, create a corresponding new PHI node with a type A. + SmallDenseMap<PHINode *, PHINode *> NewPNodes; + for (auto *OldPN : OldPhiNodes) { + Builder->SetInsertPoint(OldPN); + PHINode *NewPN = Builder->CreatePHI(DestTy, OldPN->getNumOperands()); + NewPNodes[OldPN] = NewPN; + } + + // Fill in the operands of new PHI nodes. + for (auto *OldPN : OldPhiNodes) { + PHINode *NewPN = NewPNodes[OldPN]; + for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) { + Value *V = OldPN->getOperand(j); + Value *NewV = nullptr; + if (auto *C = dyn_cast<Constant>(V)) { + NewV = ConstantExpr::getBitCast(C, DestTy); + } else if (auto *LI = dyn_cast<LoadInst>(V)) { + Builder->SetInsertPoint(LI->getNextNode()); + NewV = Builder->CreateBitCast(LI, DestTy); + Worklist.Add(LI); + } else if (auto *BCI = dyn_cast<BitCastInst>(V)) { + NewV = BCI->getOperand(0); + } else if (auto *PrevPN = dyn_cast<PHINode>(V)) { + NewV = NewPNodes[PrevPN]; + } + assert(NewV); + NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j)); + } + } + + // If there is a store with type B, change it to type A. + for (User *U : PN->users()) { + auto *SI = dyn_cast<StoreInst>(U); + if (SI && SI->isSimple() && SI->getOperand(0) == PN) { + Builder->SetInsertPoint(SI); + auto *NewBC = + cast<BitCastInst>(Builder->CreateBitCast(NewPNodes[PN], SrcTy)); + SI->setOperand(0, NewBC); + Worklist.Add(SI); + assert(hasStoreUsersOnly(*NewBC)); + } + } + + return replaceInstUsesWith(CI, NewPNodes[PN]); +} + Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { // If the operands are integer typed then apply the integer transforms, // otherwise just apply the common ones. @@ -1912,9 +2112,20 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { } } + // Handle the A->B->A cast, and there is an intervening PHI node. + if (PHINode *PN = dyn_cast<PHINode>(Src)) + if (Instruction *I = optimizeBitCastFromPhi(CI, PN)) + return I; + if (Instruction *I = canonicalizeBitCastExtElt(CI, *this, DL)) return I; + if (Instruction *I = foldBitCastBitwiseLogic(CI, *Builder)) + return I; + + if (Instruction *I = foldBitCastSelect(CI, *Builder)) + return I; + if (SrcTy->isPointerTy()) return commonPointerCastTransforms(CI); return commonCastTransforms(CI); diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp index 961497fe3c2d..012bfc7b4944 100644 --- a/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -35,17 +35,12 @@ using namespace PatternMatch; // How many times is a select replaced by one of its operands? STATISTIC(NumSel, "Number of select opts"); -// Initialization Routines -static ConstantInt *getOne(Constant *C) { - return ConstantInt::get(cast<IntegerType>(C->getType()), 1); -} - -static ConstantInt *ExtractElement(Constant *V, Constant *Idx) { +static ConstantInt *extractElement(Constant *V, Constant *Idx) { return cast<ConstantInt>(ConstantExpr::getExtractElement(V, Idx)); } -static bool HasAddOverflow(ConstantInt *Result, +static bool hasAddOverflow(ConstantInt *Result, ConstantInt *In1, ConstantInt *In2, bool IsSigned) { if (!IsSigned) @@ -58,28 +53,28 @@ static bool HasAddOverflow(ConstantInt *Result, /// Compute Result = In1+In2, returning true if the result overflowed for this /// type. -static bool AddWithOverflow(Constant *&Result, Constant *In1, +static bool addWithOverflow(Constant *&Result, Constant *In1, Constant *In2, bool IsSigned = false) { Result = ConstantExpr::getAdd(In1, In2); if (VectorType *VTy = dyn_cast<VectorType>(In1->getType())) { for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) { Constant *Idx = ConstantInt::get(Type::getInt32Ty(In1->getContext()), i); - if (HasAddOverflow(ExtractElement(Result, Idx), - ExtractElement(In1, Idx), - ExtractElement(In2, Idx), + if (hasAddOverflow(extractElement(Result, Idx), + extractElement(In1, Idx), + extractElement(In2, Idx), IsSigned)) return true; } return false; } - return HasAddOverflow(cast<ConstantInt>(Result), + return hasAddOverflow(cast<ConstantInt>(Result), cast<ConstantInt>(In1), cast<ConstantInt>(In2), IsSigned); } -static bool HasSubOverflow(ConstantInt *Result, +static bool hasSubOverflow(ConstantInt *Result, ConstantInt *In1, ConstantInt *In2, bool IsSigned) { if (!IsSigned) @@ -93,23 +88,23 @@ static bool HasSubOverflow(ConstantInt *Result, /// Compute Result = In1-In2, returning true if the result overflowed for this /// type. -static bool SubWithOverflow(Constant *&Result, Constant *In1, +static bool subWithOverflow(Constant *&Result, Constant *In1, Constant *In2, bool IsSigned = false) { Result = ConstantExpr::getSub(In1, In2); if (VectorType *VTy = dyn_cast<VectorType>(In1->getType())) { for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) { Constant *Idx = ConstantInt::get(Type::getInt32Ty(In1->getContext()), i); - if (HasSubOverflow(ExtractElement(Result, Idx), - ExtractElement(In1, Idx), - ExtractElement(In2, Idx), + if (hasSubOverflow(extractElement(Result, Idx), + extractElement(In1, Idx), + extractElement(In2, Idx), IsSigned)) return true; } return false; } - return HasSubOverflow(cast<ConstantInt>(Result), + return hasSubOverflow(cast<ConstantInt>(Result), cast<ConstantInt>(In1), cast<ConstantInt>(In2), IsSigned); } @@ -126,26 +121,26 @@ static bool isBranchOnSignBitCheck(ICmpInst &I, bool isSignBit) { /// Given an exploded icmp instruction, return true if the comparison only /// checks the sign bit. If it only checks the sign bit, set TrueIfSigned if the /// result of the comparison is true when the input value is signed. -static bool isSignBitCheck(ICmpInst::Predicate Pred, ConstantInt *RHS, +static bool isSignBitCheck(ICmpInst::Predicate Pred, const APInt &RHS, bool &TrueIfSigned) { switch (Pred) { case ICmpInst::ICMP_SLT: // True if LHS s< 0 TrueIfSigned = true; - return RHS->isZero(); + return RHS == 0; case ICmpInst::ICMP_SLE: // True if LHS s<= RHS and RHS == -1 TrueIfSigned = true; - return RHS->isAllOnesValue(); + return RHS.isAllOnesValue(); case ICmpInst::ICMP_SGT: // True if LHS s> -1 TrueIfSigned = false; - return RHS->isAllOnesValue(); + return RHS.isAllOnesValue(); case ICmpInst::ICMP_UGT: // True if LHS u> RHS and RHS == high-bit-mask - 1 TrueIfSigned = true; - return RHS->isMaxValue(true); + return RHS.isMaxSignedValue(); case ICmpInst::ICMP_UGE: // True if LHS u>= RHS and RHS == high-bit-mask (2^7, 2^15, 2^31, etc) TrueIfSigned = true; - return RHS->getValue().isSignBit(); + return RHS.isSignBit(); default: return false; } @@ -154,19 +149,20 @@ static bool isSignBitCheck(ICmpInst::Predicate Pred, ConstantInt *RHS, /// Returns true if the exploded icmp can be expressed as a signed comparison /// to zero and updates the predicate accordingly. /// The signedness of the comparison is preserved. -static bool isSignTest(ICmpInst::Predicate &Pred, const ConstantInt *RHS) { +/// TODO: Refactor with decomposeBitTestICmp()? +static bool isSignTest(ICmpInst::Predicate &Pred, const APInt &C) { if (!ICmpInst::isSigned(Pred)) return false; - if (RHS->isZero()) + if (C == 0) return ICmpInst::isRelational(Pred); - if (RHS->isOne()) { + if (C == 1) { if (Pred == ICmpInst::ICMP_SLT) { Pred = ICmpInst::ICMP_SLE; return true; } - } else if (RHS->isAllOnesValue()) { + } else if (C.isAllOnesValue()) { if (Pred == ICmpInst::ICMP_SGT) { Pred = ICmpInst::ICMP_SGE; return true; @@ -176,16 +172,10 @@ static bool isSignTest(ICmpInst::Predicate &Pred, const ConstantInt *RHS) { return false; } -/// Return true if the constant is of the form 1+0+. This is the same as -/// lowones(~X). -static bool isHighOnes(const ConstantInt *CI) { - return (~CI->getValue() + 1).isPowerOf2(); -} - /// Given a signed integer type and a set of known zero and one bits, compute /// the maximum and minimum values that could have the specified known zero and /// known one bits, returning them in Min/Max. -static void ComputeSignedMinMaxValuesFromKnownBits(const APInt &KnownZero, +static void computeSignedMinMaxValuesFromKnownBits(const APInt &KnownZero, const APInt &KnownOne, APInt &Min, APInt &Max) { assert(KnownZero.getBitWidth() == KnownOne.getBitWidth() && @@ -208,7 +198,7 @@ static void ComputeSignedMinMaxValuesFromKnownBits(const APInt &KnownZero, /// Given an unsigned integer type and a set of known zero and one bits, compute /// the maximum and minimum values that could have the specified known zero and /// known one bits, returning them in Min/Max. -static void ComputeUnsignedMinMaxValuesFromKnownBits(const APInt &KnownZero, +static void computeUnsignedMinMaxValuesFromKnownBits(const APInt &KnownZero, const APInt &KnownOne, APInt &Min, APInt &Max) { assert(KnownZero.getBitWidth() == KnownOne.getBitWidth() && @@ -231,9 +221,10 @@ static void ComputeUnsignedMinMaxValuesFromKnownBits(const APInt &KnownZero, /// /// If AndCst is non-null, then the loaded value is masked with that constant /// before doing the comparison. This handles cases like "A[i]&4 == 0". -Instruction *InstCombiner:: -FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, - CmpInst &ICI, ConstantInt *AndCst) { +Instruction *InstCombiner::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, + GlobalVariable *GV, + CmpInst &ICI, + ConstantInt *AndCst) { Constant *Init = GV->getInitializer(); if (!isa<ConstantArray>(Init) && !isa<ConstantDataArray>(Init)) return nullptr; @@ -319,7 +310,7 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, // Find out if the comparison would be true or false for the i'th element. Constant *C = ConstantFoldCompareInstOperands(ICI.getPredicate(), Elt, - CompareRHS, DL, TLI); + CompareRHS, DL, &TLI); // If the result is undef for this element, ignore it. if (isa<UndefValue>(C)) { // Extend range state machines to cover this element in case there is an @@ -509,7 +500,7 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, /// /// If we can't emit an optimized form for this expression, this returns null. /// -static Value *EvaluateGEPOffsetExpression(User *GEP, InstCombiner &IC, +static Value *evaluateGEPOffsetExpression(User *GEP, InstCombiner &IC, const DataLayout &DL) { gep_type_iterator GTI = gep_type_begin(GEP); @@ -526,7 +517,7 @@ static Value *EvaluateGEPOffsetExpression(User *GEP, InstCombiner &IC, if (CI->isZero()) continue; // Handle a struct index, which adds its field offset to the pointer. - if (StructType *STy = dyn_cast<StructType>(*GTI)) { + if (StructType *STy = GTI.getStructTypeOrNull()) { Offset += DL.getStructLayout(STy)->getElementOffset(CI->getZExtValue()); } else { uint64_t Size = DL.getTypeAllocSize(GTI.getIndexedType()); @@ -556,7 +547,7 @@ static Value *EvaluateGEPOffsetExpression(User *GEP, InstCombiner &IC, if (CI->isZero()) continue; // Handle a struct index, which adds its field offset to the pointer. - if (StructType *STy = dyn_cast<StructType>(*GTI)) { + if (StructType *STy = GTI.getStructTypeOrNull()) { Offset += DL.getStructLayout(STy)->getElementOffset(CI->getZExtValue()); } else { uint64_t Size = DL.getTypeAllocSize(GTI.getIndexedType()); @@ -919,7 +910,7 @@ static Instruction *transformToIndexedCompare(GEPOperator *GEPLHS, Value *RHS, /// Fold comparisons between a GEP instruction and something else. At this point /// we know that the GEP is on the LHS of the comparison. -Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, +Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, ICmpInst::Predicate Cond, Instruction &I) { // Don't transform signed compares of GEPs into index compares. Even if the @@ -941,7 +932,7 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, // This transformation (ignoring the base and scales) is valid because we // know pointers can't overflow since the gep is inbounds. See if we can // output an optimized form. - Value *Offset = EvaluateGEPOffsetExpression(GEPLHS, *this, DL); + Value *Offset = evaluateGEPOffsetExpression(GEPLHS, *this, DL); // If not, synthesize the offset the hard way. if (!Offset) @@ -1003,12 +994,12 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, // If one of the GEPs has all zero indices, recurse. if (GEPLHS->hasAllZeroIndices()) - return FoldGEPICmp(GEPRHS, GEPLHS->getOperand(0), + return foldGEPICmp(GEPRHS, GEPLHS->getOperand(0), ICmpInst::getSwappedPredicate(Cond), I); // If the other GEP has all zero indices, recurse. if (GEPRHS->hasAllZeroIndices()) - return FoldGEPICmp(GEPLHS, GEPRHS->getOperand(0), Cond, I); + return foldGEPICmp(GEPLHS, GEPRHS->getOperand(0), Cond, I); bool GEPsInBounds = GEPLHS->isInBounds() && GEPRHS->isInBounds(); if (GEPLHS->getNumOperands() == GEPRHS->getNumOperands()) { @@ -1056,8 +1047,9 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, return transformToIndexedCompare(GEPLHS, RHS, Cond, DL); } -Instruction *InstCombiner::FoldAllocaCmp(ICmpInst &ICI, AllocaInst *Alloca, - Value *Other) { +Instruction *InstCombiner::foldAllocaCmp(ICmpInst &ICI, + const AllocaInst *Alloca, + const Value *Other) { assert(ICI.isEquality() && "Cannot fold non-equality comparison."); // It would be tempting to fold away comparisons between allocas and any @@ -1076,8 +1068,8 @@ Instruction *InstCombiner::FoldAllocaCmp(ICmpInst &ICI, AllocaInst *Alloca, unsigned MaxIter = 32; // Break cycles and bound to constant-time. - SmallVector<Use *, 32> Worklist; - for (Use &U : Alloca->uses()) { + SmallVector<const Use *, 32> Worklist; + for (const Use &U : Alloca->uses()) { if (Worklist.size() >= MaxIter) return nullptr; Worklist.push_back(&U); @@ -1086,8 +1078,8 @@ Instruction *InstCombiner::FoldAllocaCmp(ICmpInst &ICI, AllocaInst *Alloca, unsigned NumCmps = 0; while (!Worklist.empty()) { assert(Worklist.size() <= MaxIter); - Use *U = Worklist.pop_back_val(); - Value *V = U->getUser(); + const Use *U = Worklist.pop_back_val(); + const Value *V = U->getUser(); --MaxIter; if (isa<BitCastInst>(V) || isa<GetElementPtrInst>(V) || isa<PHINode>(V) || @@ -1096,7 +1088,7 @@ Instruction *InstCombiner::FoldAllocaCmp(ICmpInst &ICI, AllocaInst *Alloca, } else if (isa<LoadInst>(V)) { // Loading from the pointer doesn't escape it. continue; - } else if (auto *SI = dyn_cast<StoreInst>(V)) { + } else if (const auto *SI = dyn_cast<StoreInst>(V)) { // Storing *to* the pointer is fine, but storing the pointer escapes it. if (SI->getValueOperand() == U->get()) return nullptr; @@ -1105,7 +1097,7 @@ Instruction *InstCombiner::FoldAllocaCmp(ICmpInst &ICI, AllocaInst *Alloca, if (NumCmps++) return nullptr; // Found more than one cmp. continue; - } else if (auto *Intrin = dyn_cast<IntrinsicInst>(V)) { + } else if (const auto *Intrin = dyn_cast<IntrinsicInst>(V)) { switch (Intrin->getIntrinsicID()) { // These intrinsics don't escape or compare the pointer. Memset is safe // because we don't allow ptrtoint. Memcpy and memmove are safe because @@ -1120,7 +1112,7 @@ Instruction *InstCombiner::FoldAllocaCmp(ICmpInst &ICI, AllocaInst *Alloca, } else { return nullptr; } - for (Use &U : V->uses()) { + for (const Use &U : V->uses()) { if (Worklist.size() >= MaxIter) return nullptr; Worklist.push_back(&U); @@ -1134,9 +1126,9 @@ Instruction *InstCombiner::FoldAllocaCmp(ICmpInst &ICI, AllocaInst *Alloca, } /// Fold "icmp pred (X+CI), X". -Instruction *InstCombiner::FoldICmpAddOpCst(Instruction &ICI, - Value *X, ConstantInt *CI, - ICmpInst::Predicate Pred) { +Instruction *InstCombiner::foldICmpAddOpConst(Instruction &ICI, + Value *X, ConstantInt *CI, + ICmpInst::Predicate Pred) { // From this point on, we know that (X+C <= X) --> (X+C < X) because C != 0, // so the values can never be equal. Similarly for all other "or equals" // operators. @@ -1181,52 +1173,995 @@ Instruction *InstCombiner::FoldICmpAddOpCst(Instruction &ICI, return new ICmpInst(ICmpInst::ICMP_SLT, X, ConstantExpr::getSub(SMax, C)); } -/// Fold "icmp pred, ([su]div X, DivRHS), CmpRHS" where DivRHS and CmpRHS are -/// both known to be integer constants. -Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, - ConstantInt *DivRHS) { - ConstantInt *CmpRHS = cast<ConstantInt>(ICI.getOperand(1)); - const APInt &CmpRHSV = CmpRHS->getValue(); +/// Handle "(icmp eq/ne (ashr/lshr AP2, A), AP1)" -> +/// (icmp eq/ne A, Log2(AP2/AP1)) -> +/// (icmp eq/ne A, Log2(AP2) - Log2(AP1)). +Instruction *InstCombiner::foldICmpShrConstConst(ICmpInst &I, Value *A, + const APInt &AP1, + const APInt &AP2) { + assert(I.isEquality() && "Cannot fold icmp gt/lt"); + + auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) { + if (I.getPredicate() == I.ICMP_NE) + Pred = CmpInst::getInversePredicate(Pred); + return new ICmpInst(Pred, LHS, RHS); + }; + + // Don't bother doing any work for cases which InstSimplify handles. + if (AP2 == 0) + return nullptr; + + bool IsAShr = isa<AShrOperator>(I.getOperand(0)); + if (IsAShr) { + if (AP2.isAllOnesValue()) + return nullptr; + if (AP2.isNegative() != AP1.isNegative()) + return nullptr; + if (AP2.sgt(AP1)) + return nullptr; + } + + if (!AP1) + // 'A' must be large enough to shift out the highest set bit. + return getICmp(I.ICMP_UGT, A, + ConstantInt::get(A->getType(), AP2.logBase2())); + + if (AP1 == AP2) + return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType())); + + int Shift; + if (IsAShr && AP1.isNegative()) + Shift = AP1.countLeadingOnes() - AP2.countLeadingOnes(); + else + Shift = AP1.countLeadingZeros() - AP2.countLeadingZeros(); + + if (Shift > 0) { + if (IsAShr && AP1 == AP2.ashr(Shift)) { + // There are multiple solutions if we are comparing against -1 and the LHS + // of the ashr is not a power of two. + if (AP1.isAllOnesValue() && !AP2.isPowerOf2()) + return getICmp(I.ICMP_UGE, A, ConstantInt::get(A->getType(), Shift)); + return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift)); + } else if (AP1 == AP2.lshr(Shift)) { + return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift)); + } + } + + // Shifting const2 will never be equal to const1. + // FIXME: This should always be handled by InstSimplify? + auto *TorF = ConstantInt::get(I.getType(), I.getPredicate() == I.ICMP_NE); + return replaceInstUsesWith(I, TorF); +} + +/// Handle "(icmp eq/ne (shl AP2, A), AP1)" -> +/// (icmp eq/ne A, TrailingZeros(AP1) - TrailingZeros(AP2)). +Instruction *InstCombiner::foldICmpShlConstConst(ICmpInst &I, Value *A, + const APInt &AP1, + const APInt &AP2) { + assert(I.isEquality() && "Cannot fold icmp gt/lt"); + + auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) { + if (I.getPredicate() == I.ICMP_NE) + Pred = CmpInst::getInversePredicate(Pred); + return new ICmpInst(Pred, LHS, RHS); + }; + + // Don't bother doing any work for cases which InstSimplify handles. + if (AP2 == 0) + return nullptr; + + unsigned AP2TrailingZeros = AP2.countTrailingZeros(); + + if (!AP1 && AP2TrailingZeros != 0) + return getICmp( + I.ICMP_UGE, A, + ConstantInt::get(A->getType(), AP2.getBitWidth() - AP2TrailingZeros)); + + if (AP1 == AP2) + return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType())); + + // Get the distance between the lowest bits that are set. + int Shift = AP1.countTrailingZeros() - AP2TrailingZeros; + + if (Shift > 0 && AP2.shl(Shift) == AP1) + return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift)); + + // Shifting const2 will never be equal to const1. + // FIXME: This should always be handled by InstSimplify? + auto *TorF = ConstantInt::get(I.getType(), I.getPredicate() == I.ICMP_NE); + return replaceInstUsesWith(I, TorF); +} + +/// The caller has matched a pattern of the form: +/// I = icmp ugt (add (add A, B), CI2), CI1 +/// If this is of the form: +/// sum = a + b +/// if (sum+128 >u 255) +/// Then replace it with llvm.sadd.with.overflow.i8. +/// +static Instruction *processUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, + ConstantInt *CI2, ConstantInt *CI1, + InstCombiner &IC) { + // The transformation we're trying to do here is to transform this into an + // llvm.sadd.with.overflow. To do this, we have to replace the original add + // with a narrower add, and discard the add-with-constant that is part of the + // range check (if we can't eliminate it, this isn't profitable). + + // In order to eliminate the add-with-constant, the compare can be its only + // use. + Instruction *AddWithCst = cast<Instruction>(I.getOperand(0)); + if (!AddWithCst->hasOneUse()) + return nullptr; + + // If CI2 is 2^7, 2^15, 2^31, then it might be an sadd.with.overflow. + if (!CI2->getValue().isPowerOf2()) + return nullptr; + unsigned NewWidth = CI2->getValue().countTrailingZeros(); + if (NewWidth != 7 && NewWidth != 15 && NewWidth != 31) + return nullptr; + + // The width of the new add formed is 1 more than the bias. + ++NewWidth; + + // Check to see that CI1 is an all-ones value with NewWidth bits. + if (CI1->getBitWidth() == NewWidth || + CI1->getValue() != APInt::getLowBitsSet(CI1->getBitWidth(), NewWidth)) + return nullptr; + + // This is only really a signed overflow check if the inputs have been + // sign-extended; check for that condition. For example, if CI2 is 2^31 and + // the operands of the add are 64 bits wide, we need at least 33 sign bits. + unsigned NeededSignBits = CI1->getBitWidth() - NewWidth + 1; + if (IC.ComputeNumSignBits(A, 0, &I) < NeededSignBits || + IC.ComputeNumSignBits(B, 0, &I) < NeededSignBits) + return nullptr; + + // In order to replace the original add with a narrower + // llvm.sadd.with.overflow, the only uses allowed are the add-with-constant + // and truncates that discard the high bits of the add. Verify that this is + // the case. + Instruction *OrigAdd = cast<Instruction>(AddWithCst->getOperand(0)); + for (User *U : OrigAdd->users()) { + if (U == AddWithCst) + continue; + + // Only accept truncates for now. We would really like a nice recursive + // predicate like SimplifyDemandedBits, but which goes downwards the use-def + // chain to see which bits of a value are actually demanded. If the + // original add had another add which was then immediately truncated, we + // could still do the transformation. + TruncInst *TI = dyn_cast<TruncInst>(U); + if (!TI || TI->getType()->getPrimitiveSizeInBits() > NewWidth) + return nullptr; + } + + // If the pattern matches, truncate the inputs to the narrower type and + // use the sadd_with_overflow intrinsic to efficiently compute both the + // result and the overflow bit. + Type *NewType = IntegerType::get(OrigAdd->getContext(), NewWidth); + Value *F = Intrinsic::getDeclaration(I.getModule(), + Intrinsic::sadd_with_overflow, NewType); + + InstCombiner::BuilderTy *Builder = IC.Builder; + + // Put the new code above the original add, in case there are any uses of the + // add between the add and the compare. + Builder->SetInsertPoint(OrigAdd); + + Value *TruncA = Builder->CreateTrunc(A, NewType, A->getName() + ".trunc"); + Value *TruncB = Builder->CreateTrunc(B, NewType, B->getName() + ".trunc"); + CallInst *Call = Builder->CreateCall(F, {TruncA, TruncB}, "sadd"); + Value *Add = Builder->CreateExtractValue(Call, 0, "sadd.result"); + Value *ZExt = Builder->CreateZExt(Add, OrigAdd->getType()); + + // The inner add was the result of the narrow add, zero extended to the + // wider type. Replace it with the result computed by the intrinsic. + IC.replaceInstUsesWith(*OrigAdd, ZExt); + + // The original icmp gets replaced with the overflow value. + return ExtractValueInst::Create(Call, 1, "sadd.overflow"); +} + +// Fold icmp Pred X, C. +Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &Cmp) { + CmpInst::Predicate Pred = Cmp.getPredicate(); + Value *X = Cmp.getOperand(0); + + const APInt *C; + if (!match(Cmp.getOperand(1), m_APInt(C))) + return nullptr; + + Value *A = nullptr, *B = nullptr; + + // Match the following pattern, which is a common idiom when writing + // overflow-safe integer arithmetic functions. The source performs an addition + // in wider type and explicitly checks for overflow using comparisons against + // INT_MIN and INT_MAX. Simplify by using the sadd_with_overflow intrinsic. + // + // TODO: This could probably be generalized to handle other overflow-safe + // operations if we worked out the formulas to compute the appropriate magic + // constants. + // + // sum = a + b + // if (sum+128 >u 255) ... -> llvm.sadd.with.overflow.i8 + { + ConstantInt *CI2; // I = icmp ugt (add (add A, B), CI2), CI + if (Pred == ICmpInst::ICMP_UGT && + match(X, m_Add(m_Add(m_Value(A), m_Value(B)), m_ConstantInt(CI2)))) + if (Instruction *Res = processUGT_ADDCST_ADD( + Cmp, A, B, CI2, cast<ConstantInt>(Cmp.getOperand(1)), *this)) + return Res; + } + + // (icmp sgt smin(PosA, B) 0) -> (icmp sgt B 0) + if (*C == 0 && Pred == ICmpInst::ICMP_SGT) { + SelectPatternResult SPR = matchSelectPattern(X, A, B); + if (SPR.Flavor == SPF_SMIN) { + if (isKnownPositive(A, DL)) + return new ICmpInst(Pred, B, Cmp.getOperand(1)); + if (isKnownPositive(B, DL)) + return new ICmpInst(Pred, A, Cmp.getOperand(1)); + } + } + + // FIXME: Use m_APInt to allow folds for splat constants. + ConstantInt *CI = dyn_cast<ConstantInt>(Cmp.getOperand(1)); + if (!CI) + return nullptr; + + // Canonicalize icmp instructions based on dominating conditions. + BasicBlock *Parent = Cmp.getParent(); + BasicBlock *Dom = Parent->getSinglePredecessor(); + auto *BI = Dom ? dyn_cast<BranchInst>(Dom->getTerminator()) : nullptr; + ICmpInst::Predicate Pred2; + BasicBlock *TrueBB, *FalseBB; + ConstantInt *CI2; + if (BI && match(BI, m_Br(m_ICmp(Pred2, m_Specific(X), m_ConstantInt(CI2)), + TrueBB, FalseBB)) && + TrueBB != FalseBB) { + ConstantRange CR = + ConstantRange::makeAllowedICmpRegion(Pred, CI->getValue()); + ConstantRange DominatingCR = + (Parent == TrueBB) + ? ConstantRange::makeExactICmpRegion(Pred2, CI2->getValue()) + : ConstantRange::makeExactICmpRegion( + CmpInst::getInversePredicate(Pred2), CI2->getValue()); + ConstantRange Intersection = DominatingCR.intersectWith(CR); + ConstantRange Difference = DominatingCR.difference(CR); + if (Intersection.isEmptySet()) + return replaceInstUsesWith(Cmp, Builder->getFalse()); + if (Difference.isEmptySet()) + return replaceInstUsesWith(Cmp, Builder->getTrue()); + + // If this is a normal comparison, it demands all bits. If it is a sign + // bit comparison, it only demands the sign bit. + bool UnusedBit; + bool IsSignBit = isSignBitCheck(Pred, CI->getValue(), UnusedBit); + + // Canonicalizing a sign bit comparison that gets used in a branch, + // pessimizes codegen by generating branch on zero instruction instead + // of a test and branch. So we avoid canonicalizing in such situations + // because test and branch instruction has better branch displacement + // than compare and branch instruction. + if (!isBranchOnSignBitCheck(Cmp, IsSignBit) && !Cmp.isEquality()) { + if (auto *AI = Intersection.getSingleElement()) + return new ICmpInst(ICmpInst::ICMP_EQ, X, Builder->getInt(*AI)); + if (auto *AD = Difference.getSingleElement()) + return new ICmpInst(ICmpInst::ICMP_NE, X, Builder->getInt(*AD)); + } + } + + return nullptr; +} + +/// Fold icmp (trunc X, Y), C. +Instruction *InstCombiner::foldICmpTruncConstant(ICmpInst &Cmp, + Instruction *Trunc, + const APInt *C) { + ICmpInst::Predicate Pred = Cmp.getPredicate(); + Value *X = Trunc->getOperand(0); + if (*C == 1 && C->getBitWidth() > 1) { + // icmp slt trunc(signum(V)) 1 --> icmp slt V, 1 + Value *V = nullptr; + if (Pred == ICmpInst::ICMP_SLT && match(X, m_Signum(m_Value(V)))) + return new ICmpInst(ICmpInst::ICMP_SLT, V, + ConstantInt::get(V->getType(), 1)); + } + + if (Cmp.isEquality() && Trunc->hasOneUse()) { + // Simplify icmp eq (trunc x to i8), 42 -> icmp eq x, 42|highbits if all + // of the high bits truncated out of x are known. + unsigned DstBits = Trunc->getType()->getScalarSizeInBits(), + SrcBits = X->getType()->getScalarSizeInBits(); + APInt KnownZero(SrcBits, 0), KnownOne(SrcBits, 0); + computeKnownBits(X, KnownZero, KnownOne, 0, &Cmp); + + // If all the high bits are known, we can do this xform. + if ((KnownZero | KnownOne).countLeadingOnes() >= SrcBits - DstBits) { + // Pull in the high bits from known-ones set. + APInt NewRHS = C->zext(SrcBits); + NewRHS |= KnownOne & APInt::getHighBitsSet(SrcBits, SrcBits - DstBits); + return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), NewRHS)); + } + } + + return nullptr; +} + +/// Fold icmp (xor X, Y), C. +Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp, + BinaryOperator *Xor, + const APInt *C) { + Value *X = Xor->getOperand(0); + Value *Y = Xor->getOperand(1); + const APInt *XorC; + if (!match(Y, m_APInt(XorC))) + return nullptr; + + // If this is a comparison that tests the signbit (X < 0) or (x > -1), + // fold the xor. + ICmpInst::Predicate Pred = Cmp.getPredicate(); + if ((Pred == ICmpInst::ICMP_SLT && *C == 0) || + (Pred == ICmpInst::ICMP_SGT && C->isAllOnesValue())) { + + // If the sign bit of the XorCst is not set, there is no change to + // the operation, just stop using the Xor. + if (!XorC->isNegative()) { + Cmp.setOperand(0, X); + Worklist.Add(Xor); + return &Cmp; + } + + // Was the old condition true if the operand is positive? + bool isTrueIfPositive = Pred == ICmpInst::ICMP_SGT; + + // If so, the new one isn't. + isTrueIfPositive ^= true; + + Constant *CmpConstant = cast<Constant>(Cmp.getOperand(1)); + if (isTrueIfPositive) + return new ICmpInst(ICmpInst::ICMP_SGT, X, SubOne(CmpConstant)); + else + return new ICmpInst(ICmpInst::ICMP_SLT, X, AddOne(CmpConstant)); + } + + if (Xor->hasOneUse()) { + // (icmp u/s (xor X SignBit), C) -> (icmp s/u X, (xor C SignBit)) + if (!Cmp.isEquality() && XorC->isSignBit()) { + Pred = Cmp.isSigned() ? Cmp.getUnsignedPredicate() + : Cmp.getSignedPredicate(); + return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), *C ^ *XorC)); + } + + // (icmp u/s (xor X ~SignBit), C) -> (icmp s/u X, (xor C ~SignBit)) + if (!Cmp.isEquality() && XorC->isMaxSignedValue()) { + Pred = Cmp.isSigned() ? Cmp.getUnsignedPredicate() + : Cmp.getSignedPredicate(); + Pred = Cmp.getSwappedPredicate(Pred); + return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), *C ^ *XorC)); + } + } + + // (icmp ugt (xor X, C), ~C) -> (icmp ult X, C) + // iff -C is a power of 2 + if (Pred == ICmpInst::ICMP_UGT && *XorC == ~(*C) && (*C + 1).isPowerOf2()) + return new ICmpInst(ICmpInst::ICMP_ULT, X, Y); + + // (icmp ult (xor X, C), -C) -> (icmp uge X, C) + // iff -C is a power of 2 + if (Pred == ICmpInst::ICMP_ULT && *XorC == -(*C) && C->isPowerOf2()) + return new ICmpInst(ICmpInst::ICMP_UGE, X, Y); + + return nullptr; +} + +/// Fold icmp (and (sh X, Y), C2), C1. +Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And, + const APInt *C1, const APInt *C2) { + BinaryOperator *Shift = dyn_cast<BinaryOperator>(And->getOperand(0)); + if (!Shift || !Shift->isShift()) + return nullptr; + + // If this is: (X >> C3) & C2 != C1 (where any shift and any compare could + // exist), turn it into (X & (C2 << C3)) != (C1 << C3). This happens a LOT in + // code produced by the clang front-end, for bitfield access. + // This seemingly simple opportunity to fold away a shift turns out to be + // rather complicated. See PR17827 for details. + unsigned ShiftOpcode = Shift->getOpcode(); + bool IsShl = ShiftOpcode == Instruction::Shl; + const APInt *C3; + if (match(Shift->getOperand(1), m_APInt(C3))) { + bool CanFold = false; + if (ShiftOpcode == Instruction::AShr) { + // There may be some constraints that make this possible, but nothing + // simple has been discovered yet. + CanFold = false; + } else if (ShiftOpcode == Instruction::Shl) { + // For a left shift, we can fold if the comparison is not signed. We can + // also fold a signed comparison if the mask value and comparison value + // are not negative. These constraints may not be obvious, but we can + // prove that they are correct using an SMT solver. + if (!Cmp.isSigned() || (!C2->isNegative() && !C1->isNegative())) + CanFold = true; + } else if (ShiftOpcode == Instruction::LShr) { + // For a logical right shift, we can fold if the comparison is not signed. + // We can also fold a signed comparison if the shifted mask value and the + // shifted comparison value are not negative. These constraints may not be + // obvious, but we can prove that they are correct using an SMT solver. + if (!Cmp.isSigned() || + (!C2->shl(*C3).isNegative() && !C1->shl(*C3).isNegative())) + CanFold = true; + } + + if (CanFold) { + APInt NewCst = IsShl ? C1->lshr(*C3) : C1->shl(*C3); + APInt SameAsC1 = IsShl ? NewCst.shl(*C3) : NewCst.lshr(*C3); + // Check to see if we are shifting out any of the bits being compared. + if (SameAsC1 != *C1) { + // If we shifted bits out, the fold is not going to work out. As a + // special case, check to see if this means that the result is always + // true or false now. + if (Cmp.getPredicate() == ICmpInst::ICMP_EQ) + return replaceInstUsesWith(Cmp, ConstantInt::getFalse(Cmp.getType())); + if (Cmp.getPredicate() == ICmpInst::ICMP_NE) + return replaceInstUsesWith(Cmp, ConstantInt::getTrue(Cmp.getType())); + } else { + Cmp.setOperand(1, ConstantInt::get(And->getType(), NewCst)); + APInt NewAndCst = IsShl ? C2->lshr(*C3) : C2->shl(*C3); + And->setOperand(1, ConstantInt::get(And->getType(), NewAndCst)); + And->setOperand(0, Shift->getOperand(0)); + Worklist.Add(Shift); // Shift is dead. + return &Cmp; + } + } + } + + // Turn ((X >> Y) & C2) == 0 into (X & (C2 << Y)) == 0. The latter is + // preferable because it allows the C2 << Y expression to be hoisted out of a + // loop if Y is invariant and X is not. + if (Shift->hasOneUse() && *C1 == 0 && Cmp.isEquality() && + !Shift->isArithmeticShift() && !isa<Constant>(Shift->getOperand(0))) { + // Compute C2 << Y. + Value *NewShift = + IsShl ? Builder->CreateLShr(And->getOperand(1), Shift->getOperand(1)) + : Builder->CreateShl(And->getOperand(1), Shift->getOperand(1)); + + // Compute X & (C2 << Y). + Value *NewAnd = Builder->CreateAnd(Shift->getOperand(0), NewShift); + Cmp.setOperand(0, NewAnd); + return &Cmp; + } + + return nullptr; +} + +/// Fold icmp (and X, C2), C1. +Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp, + BinaryOperator *And, + const APInt *C1) { + const APInt *C2; + if (!match(And->getOperand(1), m_APInt(C2))) + return nullptr; + + if (!And->hasOneUse() || !And->getOperand(0)->hasOneUse()) + return nullptr; + + // If the LHS is an 'and' of a truncate and we can widen the and/compare to + // the input width without changing the value produced, eliminate the cast: + // + // icmp (and (trunc W), C2), C1 -> icmp (and W, C2'), C1' + // + // We can do this transformation if the constants do not have their sign bits + // set or if it is an equality comparison. Extending a relational comparison + // when we're checking the sign bit would not work. + Value *W; + if (match(And->getOperand(0), m_Trunc(m_Value(W))) && + (Cmp.isEquality() || (!C1->isNegative() && !C2->isNegative()))) { + // TODO: Is this a good transform for vectors? Wider types may reduce + // throughput. Should this transform be limited (even for scalars) by using + // ShouldChangeType()? + if (!Cmp.getType()->isVectorTy()) { + Type *WideType = W->getType(); + unsigned WideScalarBits = WideType->getScalarSizeInBits(); + Constant *ZextC1 = ConstantInt::get(WideType, C1->zext(WideScalarBits)); + Constant *ZextC2 = ConstantInt::get(WideType, C2->zext(WideScalarBits)); + Value *NewAnd = Builder->CreateAnd(W, ZextC2, And->getName()); + return new ICmpInst(Cmp.getPredicate(), NewAnd, ZextC1); + } + } + + if (Instruction *I = foldICmpAndShift(Cmp, And, C1, C2)) + return I; + + // (icmp pred (and (or (lshr A, B), A), 1), 0) --> + // (icmp pred (and A, (or (shl 1, B), 1), 0)) + // + // iff pred isn't signed + if (!Cmp.isSigned() && *C1 == 0 && match(And->getOperand(1), m_One())) { + Constant *One = cast<Constant>(And->getOperand(1)); + Value *Or = And->getOperand(0); + Value *A, *B, *LShr; + if (match(Or, m_Or(m_Value(LShr), m_Value(A))) && + match(LShr, m_LShr(m_Specific(A), m_Value(B)))) { + unsigned UsesRemoved = 0; + if (And->hasOneUse()) + ++UsesRemoved; + if (Or->hasOneUse()) + ++UsesRemoved; + if (LShr->hasOneUse()) + ++UsesRemoved; + + // Compute A & ((1 << B) | 1) + Value *NewOr = nullptr; + if (auto *C = dyn_cast<Constant>(B)) { + if (UsesRemoved >= 1) + NewOr = ConstantExpr::getOr(ConstantExpr::getNUWShl(One, C), One); + } else { + if (UsesRemoved >= 3) + NewOr = Builder->CreateOr(Builder->CreateShl(One, B, LShr->getName(), + /*HasNUW=*/true), + One, Or->getName()); + } + if (NewOr) { + Value *NewAnd = Builder->CreateAnd(A, NewOr, And->getName()); + Cmp.setOperand(0, NewAnd); + return &Cmp; + } + } + } + + // (X & C2) > C1 --> (X & C2) != 0, if any bit set in (X & C2) will produce a + // result greater than C1. + unsigned NumTZ = C2->countTrailingZeros(); + if (Cmp.getPredicate() == ICmpInst::ICMP_UGT && NumTZ < C2->getBitWidth() && + APInt::getOneBitSet(C2->getBitWidth(), NumTZ).ugt(*C1)) { + Constant *Zero = Constant::getNullValue(And->getType()); + return new ICmpInst(ICmpInst::ICMP_NE, And, Zero); + } + + return nullptr; +} + +/// Fold icmp (and X, Y), C. +Instruction *InstCombiner::foldICmpAndConstant(ICmpInst &Cmp, + BinaryOperator *And, + const APInt *C) { + if (Instruction *I = foldICmpAndConstConst(Cmp, And, C)) + return I; + + // TODO: These all require that Y is constant too, so refactor with the above. + + // Try to optimize things like "A[i] & 42 == 0" to index computations. + Value *X = And->getOperand(0); + Value *Y = And->getOperand(1); + if (auto *LI = dyn_cast<LoadInst>(X)) + if (auto *GEP = dyn_cast<GetElementPtrInst>(LI->getOperand(0))) + if (auto *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0))) + if (GV->isConstant() && GV->hasDefinitiveInitializer() && + !LI->isVolatile() && isa<ConstantInt>(Y)) { + ConstantInt *C2 = cast<ConstantInt>(Y); + if (Instruction *Res = foldCmpLoadFromIndexedGlobal(GEP, GV, Cmp, C2)) + return Res; + } + + if (!Cmp.isEquality()) + return nullptr; + + // X & -C == -C -> X > u ~C + // X & -C != -C -> X <= u ~C + // iff C is a power of 2 + if (Cmp.getOperand(1) == Y && (-(*C)).isPowerOf2()) { + auto NewPred = Cmp.getPredicate() == CmpInst::ICMP_EQ ? CmpInst::ICMP_UGT + : CmpInst::ICMP_ULE; + return new ICmpInst(NewPred, X, SubOne(cast<Constant>(Cmp.getOperand(1)))); + } + + // (X & C2) == 0 -> (trunc X) >= 0 + // (X & C2) != 0 -> (trunc X) < 0 + // iff C2 is a power of 2 and it masks the sign bit of a legal integer type. + const APInt *C2; + if (And->hasOneUse() && *C == 0 && match(Y, m_APInt(C2))) { + int32_t ExactLogBase2 = C2->exactLogBase2(); + if (ExactLogBase2 != -1 && DL.isLegalInteger(ExactLogBase2 + 1)) { + Type *NTy = IntegerType::get(Cmp.getContext(), ExactLogBase2 + 1); + if (And->getType()->isVectorTy()) + NTy = VectorType::get(NTy, And->getType()->getVectorNumElements()); + Value *Trunc = Builder->CreateTrunc(X, NTy); + auto NewPred = Cmp.getPredicate() == CmpInst::ICMP_EQ ? CmpInst::ICMP_SGE + : CmpInst::ICMP_SLT; + return new ICmpInst(NewPred, Trunc, Constant::getNullValue(NTy)); + } + } + + return nullptr; +} + +/// Fold icmp (or X, Y), C. +Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or, + const APInt *C) { + ICmpInst::Predicate Pred = Cmp.getPredicate(); + if (*C == 1) { + // icmp slt signum(V) 1 --> icmp slt V, 1 + Value *V = nullptr; + if (Pred == ICmpInst::ICMP_SLT && match(Or, m_Signum(m_Value(V)))) + return new ICmpInst(ICmpInst::ICMP_SLT, V, + ConstantInt::get(V->getType(), 1)); + } + + if (!Cmp.isEquality() || *C != 0 || !Or->hasOneUse()) + return nullptr; + + Value *P, *Q; + if (match(Or, m_Or(m_PtrToInt(m_Value(P)), m_PtrToInt(m_Value(Q))))) { + // Simplify icmp eq (or (ptrtoint P), (ptrtoint Q)), 0 + // -> and (icmp eq P, null), (icmp eq Q, null). + Value *CmpP = + Builder->CreateICmp(Pred, P, ConstantInt::getNullValue(P->getType())); + Value *CmpQ = + Builder->CreateICmp(Pred, Q, ConstantInt::getNullValue(Q->getType())); + auto LogicOpc = Pred == ICmpInst::Predicate::ICMP_EQ ? Instruction::And + : Instruction::Or; + return BinaryOperator::Create(LogicOpc, CmpP, CmpQ); + } + + return nullptr; +} + +/// Fold icmp (mul X, Y), C. +Instruction *InstCombiner::foldICmpMulConstant(ICmpInst &Cmp, + BinaryOperator *Mul, + const APInt *C) { + const APInt *MulC; + if (!match(Mul->getOperand(1), m_APInt(MulC))) + return nullptr; + + // If this is a test of the sign bit and the multiply is sign-preserving with + // a constant operand, use the multiply LHS operand instead. + ICmpInst::Predicate Pred = Cmp.getPredicate(); + if (isSignTest(Pred, *C) && Mul->hasNoSignedWrap()) { + if (MulC->isNegative()) + Pred = ICmpInst::getSwappedPredicate(Pred); + return new ICmpInst(Pred, Mul->getOperand(0), + Constant::getNullValue(Mul->getType())); + } + + return nullptr; +} + +/// Fold icmp (shl 1, Y), C. +static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl, + const APInt *C) { + Value *Y; + if (!match(Shl, m_Shl(m_One(), m_Value(Y)))) + return nullptr; + + Type *ShiftType = Shl->getType(); + uint32_t TypeBits = C->getBitWidth(); + bool CIsPowerOf2 = C->isPowerOf2(); + ICmpInst::Predicate Pred = Cmp.getPredicate(); + if (Cmp.isUnsigned()) { + // (1 << Y) pred C -> Y pred Log2(C) + if (!CIsPowerOf2) { + // (1 << Y) < 30 -> Y <= 4 + // (1 << Y) <= 30 -> Y <= 4 + // (1 << Y) >= 30 -> Y > 4 + // (1 << Y) > 30 -> Y > 4 + if (Pred == ICmpInst::ICMP_ULT) + Pred = ICmpInst::ICMP_ULE; + else if (Pred == ICmpInst::ICMP_UGE) + Pred = ICmpInst::ICMP_UGT; + } + + // (1 << Y) >= 2147483648 -> Y >= 31 -> Y == 31 + // (1 << Y) < 2147483648 -> Y < 31 -> Y != 31 + unsigned CLog2 = C->logBase2(); + if (CLog2 == TypeBits - 1) { + if (Pred == ICmpInst::ICMP_UGE) + Pred = ICmpInst::ICMP_EQ; + else if (Pred == ICmpInst::ICMP_ULT) + Pred = ICmpInst::ICMP_NE; + } + return new ICmpInst(Pred, Y, ConstantInt::get(ShiftType, CLog2)); + } else if (Cmp.isSigned()) { + Constant *BitWidthMinusOne = ConstantInt::get(ShiftType, TypeBits - 1); + if (C->isAllOnesValue()) { + // (1 << Y) <= -1 -> Y == 31 + if (Pred == ICmpInst::ICMP_SLE) + return new ICmpInst(ICmpInst::ICMP_EQ, Y, BitWidthMinusOne); + + // (1 << Y) > -1 -> Y != 31 + if (Pred == ICmpInst::ICMP_SGT) + return new ICmpInst(ICmpInst::ICMP_NE, Y, BitWidthMinusOne); + } else if (!(*C)) { + // (1 << Y) < 0 -> Y == 31 + // (1 << Y) <= 0 -> Y == 31 + if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) + return new ICmpInst(ICmpInst::ICMP_EQ, Y, BitWidthMinusOne); + + // (1 << Y) >= 0 -> Y != 31 + // (1 << Y) > 0 -> Y != 31 + if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) + return new ICmpInst(ICmpInst::ICMP_NE, Y, BitWidthMinusOne); + } + } else if (Cmp.isEquality() && CIsPowerOf2) { + return new ICmpInst(Pred, Y, ConstantInt::get(ShiftType, C->logBase2())); + } + + return nullptr; +} + +/// Fold icmp (shl X, Y), C. +Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, + BinaryOperator *Shl, + const APInt *C) { + const APInt *ShiftVal; + if (Cmp.isEquality() && match(Shl->getOperand(0), m_APInt(ShiftVal))) + return foldICmpShlConstConst(Cmp, Shl->getOperand(1), *C, *ShiftVal); + + const APInt *ShiftAmt; + if (!match(Shl->getOperand(1), m_APInt(ShiftAmt))) + return foldICmpShlOne(Cmp, Shl, C); + + // Check that the shift amount is in range. If not, don't perform undefined + // shifts. When the shift is visited it will be simplified. + unsigned TypeBits = C->getBitWidth(); + if (ShiftAmt->uge(TypeBits)) + return nullptr; + + ICmpInst::Predicate Pred = Cmp.getPredicate(); + Value *X = Shl->getOperand(0); + if (Cmp.isEquality()) { + // If the shift is NUW, then it is just shifting out zeros, no need for an + // AND. + Constant *LShrC = ConstantInt::get(Shl->getType(), C->lshr(*ShiftAmt)); + if (Shl->hasNoUnsignedWrap()) + return new ICmpInst(Pred, X, LShrC); + + // If the shift is NSW and we compare to 0, then it is just shifting out + // sign bits, no need for an AND either. + if (Shl->hasNoSignedWrap() && *C == 0) + return new ICmpInst(Pred, X, LShrC); + + if (Shl->hasOneUse()) { + // Otherwise strength reduce the shift into an and. + Constant *Mask = ConstantInt::get(Shl->getType(), + APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt->getZExtValue())); + + Value *And = Builder->CreateAnd(X, Mask, Shl->getName() + ".mask"); + return new ICmpInst(Pred, And, LShrC); + } + } + + // If this is a signed comparison to 0 and the shift is sign preserving, + // use the shift LHS operand instead; isSignTest may change 'Pred', so only + // do that if we're sure to not continue on in this function. + if (Shl->hasNoSignedWrap() && isSignTest(Pred, *C)) + return new ICmpInst(Pred, X, Constant::getNullValue(X->getType())); + + // Otherwise, if this is a comparison of the sign bit, simplify to and/test. + bool TrueIfSigned = false; + if (Shl->hasOneUse() && isSignBitCheck(Pred, *C, TrueIfSigned)) { + // (X << 31) <s 0 --> (X & 1) != 0 + Constant *Mask = ConstantInt::get( + X->getType(), + APInt::getOneBitSet(TypeBits, TypeBits - ShiftAmt->getZExtValue() - 1)); + Value *And = Builder->CreateAnd(X, Mask, Shl->getName() + ".mask"); + return new ICmpInst(TrueIfSigned ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ, + And, Constant::getNullValue(And->getType())); + } + + // When the shift is nuw and pred is >u or <=u, comparison only really happens + // in the pre-shifted bits. Since InstSimplify canoncalizes <=u into <u, the + // <=u case can be further converted to match <u (see below). + if (Shl->hasNoUnsignedWrap() && + (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULT)) { + // Derivation for the ult case: + // (X << S) <=u C is equiv to X <=u (C >> S) for all C + // (X << S) <u (C + 1) is equiv to X <u (C >> S) + 1 if C <u ~0u + // (X << S) <u C is equiv to X <u ((C - 1) >> S) + 1 if C >u 0 + assert((Pred != ICmpInst::ICMP_ULT || C->ugt(0)) && + "Encountered `ult 0` that should have been eliminated by " + "InstSimplify."); + APInt ShiftedC = Pred == ICmpInst::ICMP_ULT ? (*C - 1).lshr(*ShiftAmt) + 1 + : C->lshr(*ShiftAmt); + return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), ShiftedC)); + } + + // Transform (icmp pred iM (shl iM %v, N), C) + // -> (icmp pred i(M-N) (trunc %v iM to i(M-N)), (trunc (C>>N)) + // Transform the shl to a trunc if (trunc (C>>N)) has no loss and M-N. + // This enables us to get rid of the shift in favor of a trunc which can be + // free on the target. It has the additional benefit of comparing to a + // smaller constant, which will be target friendly. + unsigned Amt = ShiftAmt->getLimitedValue(TypeBits - 1); + if (Shl->hasOneUse() && Amt != 0 && C->countTrailingZeros() >= Amt && + DL.isLegalInteger(TypeBits - Amt)) { + Type *TruncTy = IntegerType::get(Cmp.getContext(), TypeBits - Amt); + if (X->getType()->isVectorTy()) + TruncTy = VectorType::get(TruncTy, X->getType()->getVectorNumElements()); + Constant *NewC = + ConstantInt::get(TruncTy, C->ashr(*ShiftAmt).trunc(TypeBits - Amt)); + return new ICmpInst(Pred, Builder->CreateTrunc(X, TruncTy), NewC); + } + + return nullptr; +} + +/// Fold icmp ({al}shr X, Y), C. +Instruction *InstCombiner::foldICmpShrConstant(ICmpInst &Cmp, + BinaryOperator *Shr, + const APInt *C) { + // An exact shr only shifts out zero bits, so: + // icmp eq/ne (shr X, Y), 0 --> icmp eq/ne X, 0 + Value *X = Shr->getOperand(0); + CmpInst::Predicate Pred = Cmp.getPredicate(); + if (Cmp.isEquality() && Shr->isExact() && Shr->hasOneUse() && *C == 0) + return new ICmpInst(Pred, X, Cmp.getOperand(1)); + + const APInt *ShiftVal; + if (Cmp.isEquality() && match(Shr->getOperand(0), m_APInt(ShiftVal))) + return foldICmpShrConstConst(Cmp, Shr->getOperand(1), *C, *ShiftVal); + + const APInt *ShiftAmt; + if (!match(Shr->getOperand(1), m_APInt(ShiftAmt))) + return nullptr; + + // Check that the shift amount is in range. If not, don't perform undefined + // shifts. When the shift is visited it will be simplified. + unsigned TypeBits = C->getBitWidth(); + unsigned ShAmtVal = ShiftAmt->getLimitedValue(TypeBits); + if (ShAmtVal >= TypeBits || ShAmtVal == 0) + return nullptr; + + bool IsAShr = Shr->getOpcode() == Instruction::AShr; + if (!Cmp.isEquality()) { + // If we have an unsigned comparison and an ashr, we can't simplify this. + // Similarly for signed comparisons with lshr. + if (Cmp.isSigned() != IsAShr) + return nullptr; + + // Otherwise, all lshr and most exact ashr's are equivalent to a udiv/sdiv + // by a power of 2. Since we already have logic to simplify these, + // transform to div and then simplify the resultant comparison. + if (IsAShr && (!Shr->isExact() || ShAmtVal == TypeBits - 1)) + return nullptr; + + // Revisit the shift (to delete it). + Worklist.Add(Shr); + + Constant *DivCst = ConstantInt::get( + Shr->getType(), APInt::getOneBitSet(TypeBits, ShAmtVal)); + + Value *Tmp = IsAShr ? Builder->CreateSDiv(X, DivCst, "", Shr->isExact()) + : Builder->CreateUDiv(X, DivCst, "", Shr->isExact()); + + Cmp.setOperand(0, Tmp); + + // If the builder folded the binop, just return it. + BinaryOperator *TheDiv = dyn_cast<BinaryOperator>(Tmp); + if (!TheDiv) + return &Cmp; + + // Otherwise, fold this div/compare. + assert(TheDiv->getOpcode() == Instruction::SDiv || + TheDiv->getOpcode() == Instruction::UDiv); + + Instruction *Res = foldICmpDivConstant(Cmp, TheDiv, C); + assert(Res && "This div/cst should have folded!"); + return Res; + } + + // Handle equality comparisons of shift-by-constant. + + // If the comparison constant changes with the shift, the comparison cannot + // succeed (bits of the comparison constant cannot match the shifted value). + // This should be known by InstSimplify and already be folded to true/false. + assert(((IsAShr && C->shl(ShAmtVal).ashr(ShAmtVal) == *C) || + (!IsAShr && C->shl(ShAmtVal).lshr(ShAmtVal) == *C)) && + "Expected icmp+shr simplify did not occur."); + + // Check if the bits shifted out are known to be zero. If so, we can compare + // against the unshifted value: + // (X & 4) >> 1 == 2 --> (X & 4) == 4. + Constant *ShiftedCmpRHS = ConstantInt::get(Shr->getType(), *C << ShAmtVal); + if (Shr->hasOneUse()) { + if (Shr->isExact()) + return new ICmpInst(Pred, X, ShiftedCmpRHS); + + // Otherwise strength reduce the shift into an 'and'. + APInt Val(APInt::getHighBitsSet(TypeBits, TypeBits - ShAmtVal)); + Constant *Mask = ConstantInt::get(Shr->getType(), Val); + Value *And = Builder->CreateAnd(X, Mask, Shr->getName() + ".mask"); + return new ICmpInst(Pred, And, ShiftedCmpRHS); + } + + return nullptr; +} + +/// Fold icmp (udiv X, Y), C. +Instruction *InstCombiner::foldICmpUDivConstant(ICmpInst &Cmp, + BinaryOperator *UDiv, + const APInt *C) { + const APInt *C2; + if (!match(UDiv->getOperand(0), m_APInt(C2))) + return nullptr; + + assert(C2 != 0 && "udiv 0, X should have been simplified already."); + + // (icmp ugt (udiv C2, Y), C) -> (icmp ule Y, C2/(C+1)) + Value *Y = UDiv->getOperand(1); + if (Cmp.getPredicate() == ICmpInst::ICMP_UGT) { + assert(!C->isMaxValue() && + "icmp ugt X, UINT_MAX should have been simplified already."); + return new ICmpInst(ICmpInst::ICMP_ULE, Y, + ConstantInt::get(Y->getType(), C2->udiv(*C + 1))); + } + + // (icmp ult (udiv C2, Y), C) -> (icmp ugt Y, C2/C) + if (Cmp.getPredicate() == ICmpInst::ICMP_ULT) { + assert(C != 0 && "icmp ult X, 0 should have been simplified already."); + return new ICmpInst(ICmpInst::ICMP_UGT, Y, + ConstantInt::get(Y->getType(), C2->udiv(*C))); + } + + return nullptr; +} + +/// Fold icmp ({su}div X, Y), C. +Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp, + BinaryOperator *Div, + const APInt *C) { + // Fold: icmp pred ([us]div X, C2), C -> range test + // Fold this div into the comparison, producing a range check. + // Determine, based on the divide type, what the range is being + // checked. If there is an overflow on the low or high side, remember + // it, otherwise compute the range [low, hi) bounding the new value. + // See: InsertRangeTest above for the kinds of replacements possible. + const APInt *C2; + if (!match(Div->getOperand(1), m_APInt(C2))) + return nullptr; // FIXME: If the operand types don't match the type of the divide // then don't attempt this transform. The code below doesn't have the // logic to deal with a signed divide and an unsigned compare (and - // vice versa). This is because (x /s C1) <s C2 produces different - // results than (x /s C1) <u C2 or (x /u C1) <s C2 or even - // (x /u C1) <u C2. Simply casting the operands and result won't + // vice versa). This is because (x /s C2) <s C produces different + // results than (x /s C2) <u C or (x /u C2) <s C or even + // (x /u C2) <u C. Simply casting the operands and result won't // work. :( The if statement below tests that condition and bails // if it finds it. - bool DivIsSigned = DivI->getOpcode() == Instruction::SDiv; - if (!ICI.isEquality() && DivIsSigned != ICI.isSigned()) + bool DivIsSigned = Div->getOpcode() == Instruction::SDiv; + if (!Cmp.isEquality() && DivIsSigned != Cmp.isSigned()) + return nullptr; + + // The ProdOV computation fails on divide by 0 and divide by -1. Cases with + // INT_MIN will also fail if the divisor is 1. Although folds of all these + // division-by-constant cases should be present, we can not assert that they + // have happened before we reach this icmp instruction. + if (*C2 == 0 || *C2 == 1 || (DivIsSigned && C2->isAllOnesValue())) return nullptr; - if (DivRHS->isZero()) - return nullptr; // The ProdOV computation fails on divide by zero. - if (DivIsSigned && DivRHS->isAllOnesValue()) - return nullptr; // The overflow computation also screws up here - if (DivRHS->isOne()) { - // This eliminates some funny cases with INT_MIN. - ICI.setOperand(0, DivI->getOperand(0)); // X/1 == X. - return &ICI; - } - // Compute Prod = CI * DivRHS. We are essentially solving an equation - // of form X/C1=C2. We solve for X by multiplying C1 (DivRHS) and - // C2 (CI). By solving for X we can turn this into a range check - // instead of computing a divide. + // TODO: We could do all of the computations below using APInt. + Constant *CmpRHS = cast<Constant>(Cmp.getOperand(1)); + Constant *DivRHS = cast<Constant>(Div->getOperand(1)); + + // Compute Prod = CmpRHS * DivRHS. We are essentially solving an equation of + // form X / C2 = C. We solve for X by multiplying C2 (DivRHS) and C (CmpRHS). + // By solving for X, we can turn this into a range check instead of computing + // a divide. Constant *Prod = ConstantExpr::getMul(CmpRHS, DivRHS); - // Determine if the product overflows by seeing if the product is - // not equal to the divide. Make sure we do the same kind of divide - // as in the LHS instruction that we're folding. - bool ProdOV = (DivIsSigned ? ConstantExpr::getSDiv(Prod, DivRHS) : - ConstantExpr::getUDiv(Prod, DivRHS)) != CmpRHS; + // Determine if the product overflows by seeing if the product is not equal to + // the divide. Make sure we do the same kind of divide as in the LHS + // instruction that we're folding. + bool ProdOV = (DivIsSigned ? ConstantExpr::getSDiv(Prod, DivRHS) + : ConstantExpr::getUDiv(Prod, DivRHS)) != CmpRHS; - // Get the ICmp opcode - ICmpInst::Predicate Pred = ICI.getPredicate(); + ICmpInst::Predicate Pred = Cmp.getPredicate(); // If the division is known to be exact, then there is no remainder from the // divide, so the covered range size is unit, otherwise it is the divisor. - ConstantInt *RangeSize = DivI->isExact() ? getOne(Prod) : DivRHS; + Constant *RangeSize = + Div->isExact() ? ConstantInt::get(Div->getType(), 1) : DivRHS; // Figure out the interval that is being checked. For example, a comparison // like "X /u 5 == 0" is really checking that X is in the interval [0, 5). @@ -1245,1134 +2180,1094 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, if (!HiOverflow) { // If this is not an exact divide, then many values in the range collapse // to the same result value. - HiOverflow = AddWithOverflow(HiBound, LoBound, RangeSize, false); + HiOverflow = addWithOverflow(HiBound, LoBound, RangeSize, false); } - } else if (DivRHS->getValue().isStrictlyPositive()) { // Divisor is > 0. - if (CmpRHSV == 0) { // (X / pos) op 0 + } else if (C2->isStrictlyPositive()) { // Divisor is > 0. + if (*C == 0) { // (X / pos) op 0 // Can't overflow. e.g. X/2 op 0 --> [-1, 2) LoBound = ConstantExpr::getNeg(SubOne(RangeSize)); HiBound = RangeSize; - } else if (CmpRHSV.isStrictlyPositive()) { // (X / pos) op pos + } else if (C->isStrictlyPositive()) { // (X / pos) op pos LoBound = Prod; // e.g. X/5 op 3 --> [15, 20) HiOverflow = LoOverflow = ProdOV; if (!HiOverflow) - HiOverflow = AddWithOverflow(HiBound, Prod, RangeSize, true); + HiOverflow = addWithOverflow(HiBound, Prod, RangeSize, true); } else { // (X / pos) op neg // e.g. X/5 op -3 --> [-15-4, -15+1) --> [-19, -14) HiBound = AddOne(Prod); LoOverflow = HiOverflow = ProdOV ? -1 : 0; if (!LoOverflow) { - ConstantInt *DivNeg =cast<ConstantInt>(ConstantExpr::getNeg(RangeSize)); - LoOverflow = AddWithOverflow(LoBound, HiBound, DivNeg, true) ? -1 : 0; + Constant *DivNeg = ConstantExpr::getNeg(RangeSize); + LoOverflow = addWithOverflow(LoBound, HiBound, DivNeg, true) ? -1 : 0; } } - } else if (DivRHS->isNegative()) { // Divisor is < 0. - if (DivI->isExact()) - RangeSize = cast<ConstantInt>(ConstantExpr::getNeg(RangeSize)); - if (CmpRHSV == 0) { // (X / neg) op 0 + } else if (C2->isNegative()) { // Divisor is < 0. + if (Div->isExact()) + RangeSize = ConstantExpr::getNeg(RangeSize); + if (*C == 0) { // (X / neg) op 0 // e.g. X/-5 op 0 --> [-4, 5) LoBound = AddOne(RangeSize); - HiBound = cast<ConstantInt>(ConstantExpr::getNeg(RangeSize)); + HiBound = ConstantExpr::getNeg(RangeSize); if (HiBound == DivRHS) { // -INTMIN = INTMIN HiOverflow = 1; // [INTMIN+1, overflow) HiBound = nullptr; // e.g. X/INTMIN = 0 --> X > INTMIN } - } else if (CmpRHSV.isStrictlyPositive()) { // (X / neg) op pos + } else if (C->isStrictlyPositive()) { // (X / neg) op pos // e.g. X/-5 op 3 --> [-19, -14) HiBound = AddOne(Prod); HiOverflow = LoOverflow = ProdOV ? -1 : 0; if (!LoOverflow) - LoOverflow = AddWithOverflow(LoBound, HiBound, RangeSize, true) ? -1:0; + LoOverflow = addWithOverflow(LoBound, HiBound, RangeSize, true) ? -1:0; } else { // (X / neg) op neg LoBound = Prod; // e.g. X/-5 op -3 --> [15, 20) LoOverflow = HiOverflow = ProdOV; if (!HiOverflow) - HiOverflow = SubWithOverflow(HiBound, Prod, RangeSize, true); + HiOverflow = subWithOverflow(HiBound, Prod, RangeSize, true); } // Dividing by a negative swaps the condition. LT <-> GT Pred = ICmpInst::getSwappedPredicate(Pred); } - Value *X = DivI->getOperand(0); + Value *X = Div->getOperand(0); switch (Pred) { - default: llvm_unreachable("Unhandled icmp opcode!"); - case ICmpInst::ICMP_EQ: - if (LoOverflow && HiOverflow) - return replaceInstUsesWith(ICI, Builder->getFalse()); - if (HiOverflow) - return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE : - ICmpInst::ICMP_UGE, X, LoBound); - if (LoOverflow) - return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : - ICmpInst::ICMP_ULT, X, HiBound); - return replaceInstUsesWith(ICI, InsertRangeTest(X, LoBound, HiBound, - DivIsSigned, true)); - case ICmpInst::ICMP_NE: - if (LoOverflow && HiOverflow) - return replaceInstUsesWith(ICI, Builder->getTrue()); - if (HiOverflow) - return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : - ICmpInst::ICMP_ULT, X, LoBound); - if (LoOverflow) - return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE : - ICmpInst::ICMP_UGE, X, HiBound); - return replaceInstUsesWith(ICI, InsertRangeTest(X, LoBound, HiBound, - DivIsSigned, false)); - case ICmpInst::ICMP_ULT: - case ICmpInst::ICMP_SLT: - if (LoOverflow == +1) // Low bound is greater than input range. - return replaceInstUsesWith(ICI, Builder->getTrue()); - if (LoOverflow == -1) // Low bound is less than input range. - return replaceInstUsesWith(ICI, Builder->getFalse()); - return new ICmpInst(Pred, X, LoBound); - case ICmpInst::ICMP_UGT: - case ICmpInst::ICMP_SGT: - if (HiOverflow == +1) // High bound greater than input range. - return replaceInstUsesWith(ICI, Builder->getFalse()); - if (HiOverflow == -1) // High bound less than input range. - return replaceInstUsesWith(ICI, Builder->getTrue()); - if (Pred == ICmpInst::ICMP_UGT) - return new ICmpInst(ICmpInst::ICMP_UGE, X, HiBound); - return new ICmpInst(ICmpInst::ICMP_SGE, X, HiBound); + default: llvm_unreachable("Unhandled icmp opcode!"); + case ICmpInst::ICMP_EQ: + if (LoOverflow && HiOverflow) + return replaceInstUsesWith(Cmp, Builder->getFalse()); + if (HiOverflow) + return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE : + ICmpInst::ICMP_UGE, X, LoBound); + if (LoOverflow) + return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : + ICmpInst::ICMP_ULT, X, HiBound); + return replaceInstUsesWith( + Cmp, insertRangeTest(X, LoBound->getUniqueInteger(), + HiBound->getUniqueInteger(), DivIsSigned, true)); + case ICmpInst::ICMP_NE: + if (LoOverflow && HiOverflow) + return replaceInstUsesWith(Cmp, Builder->getTrue()); + if (HiOverflow) + return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : + ICmpInst::ICMP_ULT, X, LoBound); + if (LoOverflow) + return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE : + ICmpInst::ICMP_UGE, X, HiBound); + return replaceInstUsesWith(Cmp, + insertRangeTest(X, LoBound->getUniqueInteger(), + HiBound->getUniqueInteger(), + DivIsSigned, false)); + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_SLT: + if (LoOverflow == +1) // Low bound is greater than input range. + return replaceInstUsesWith(Cmp, Builder->getTrue()); + if (LoOverflow == -1) // Low bound is less than input range. + return replaceInstUsesWith(Cmp, Builder->getFalse()); + return new ICmpInst(Pred, X, LoBound); + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_SGT: + if (HiOverflow == +1) // High bound greater than input range. + return replaceInstUsesWith(Cmp, Builder->getFalse()); + if (HiOverflow == -1) // High bound less than input range. + return replaceInstUsesWith(Cmp, Builder->getTrue()); + if (Pred == ICmpInst::ICMP_UGT) + return new ICmpInst(ICmpInst::ICMP_UGE, X, HiBound); + return new ICmpInst(ICmpInst::ICMP_SGE, X, HiBound); } + + return nullptr; } -/// Handle "icmp(([al]shr X, cst1), cst2)". -Instruction *InstCombiner::FoldICmpShrCst(ICmpInst &ICI, BinaryOperator *Shr, - ConstantInt *ShAmt) { - const APInt &CmpRHSV = cast<ConstantInt>(ICI.getOperand(1))->getValue(); +/// Fold icmp (sub X, Y), C. +Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp, + BinaryOperator *Sub, + const APInt *C) { + Value *X = Sub->getOperand(0), *Y = Sub->getOperand(1); + ICmpInst::Predicate Pred = Cmp.getPredicate(); - // Check that the shift amount is in range. If not, don't perform - // undefined shifts. When the shift is visited it will be - // simplified. - uint32_t TypeBits = CmpRHSV.getBitWidth(); - uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits); - if (ShAmtVal >= TypeBits || ShAmtVal == 0) + // The following transforms are only worth it if the only user of the subtract + // is the icmp. + if (!Sub->hasOneUse()) return nullptr; - if (!ICI.isEquality()) { - // If we have an unsigned comparison and an ashr, we can't simplify this. - // Similarly for signed comparisons with lshr. - if (ICI.isSigned() != (Shr->getOpcode() == Instruction::AShr)) - return nullptr; - - // Otherwise, all lshr and most exact ashr's are equivalent to a udiv/sdiv - // by a power of 2. Since we already have logic to simplify these, - // transform to div and then simplify the resultant comparison. - if (Shr->getOpcode() == Instruction::AShr && - (!Shr->isExact() || ShAmtVal == TypeBits - 1)) - return nullptr; + if (Sub->hasNoSignedWrap()) { + // (icmp sgt (sub nsw X, Y), -1) -> (icmp sge X, Y) + if (Pred == ICmpInst::ICMP_SGT && C->isAllOnesValue()) + return new ICmpInst(ICmpInst::ICMP_SGE, X, Y); - // Revisit the shift (to delete it). - Worklist.Add(Shr); + // (icmp sgt (sub nsw X, Y), 0) -> (icmp sgt X, Y) + if (Pred == ICmpInst::ICMP_SGT && *C == 0) + return new ICmpInst(ICmpInst::ICMP_SGT, X, Y); - Constant *DivCst = - ConstantInt::get(Shr->getType(), APInt::getOneBitSet(TypeBits, ShAmtVal)); - - Value *Tmp = - Shr->getOpcode() == Instruction::AShr ? - Builder->CreateSDiv(Shr->getOperand(0), DivCst, "", Shr->isExact()) : - Builder->CreateUDiv(Shr->getOperand(0), DivCst, "", Shr->isExact()); - - ICI.setOperand(0, Tmp); - - // If the builder folded the binop, just return it. - BinaryOperator *TheDiv = dyn_cast<BinaryOperator>(Tmp); - if (!TheDiv) - return &ICI; - - // Otherwise, fold this div/compare. - assert(TheDiv->getOpcode() == Instruction::SDiv || - TheDiv->getOpcode() == Instruction::UDiv); + // (icmp slt (sub nsw X, Y), 0) -> (icmp slt X, Y) + if (Pred == ICmpInst::ICMP_SLT && *C == 0) + return new ICmpInst(ICmpInst::ICMP_SLT, X, Y); - Instruction *Res = FoldICmpDivCst(ICI, TheDiv, cast<ConstantInt>(DivCst)); - assert(Res && "This div/cst should have folded!"); - return Res; + // (icmp slt (sub nsw X, Y), 1) -> (icmp sle X, Y) + if (Pred == ICmpInst::ICMP_SLT && *C == 1) + return new ICmpInst(ICmpInst::ICMP_SLE, X, Y); } - // If we are comparing against bits always shifted out, the - // comparison cannot succeed. - APInt Comp = CmpRHSV << ShAmtVal; - ConstantInt *ShiftedCmpRHS = Builder->getInt(Comp); - if (Shr->getOpcode() == Instruction::LShr) - Comp = Comp.lshr(ShAmtVal); - else - Comp = Comp.ashr(ShAmtVal); - - if (Comp != CmpRHSV) { // Comparing against a bit that we know is zero. - bool IsICMP_NE = ICI.getPredicate() == ICmpInst::ICMP_NE; - Constant *Cst = Builder->getInt1(IsICMP_NE); - return replaceInstUsesWith(ICI, Cst); - } + const APInt *C2; + if (!match(X, m_APInt(C2))) + return nullptr; - // Otherwise, check to see if the bits shifted out are known to be zero. - // If so, we can compare against the unshifted value: - // (X & 4) >> 1 == 2 --> (X & 4) == 4. - if (Shr->hasOneUse() && Shr->isExact()) - return new ICmpInst(ICI.getPredicate(), Shr->getOperand(0), ShiftedCmpRHS); + // C2 - Y <u C -> (Y | (C - 1)) == C2 + // iff (C2 & (C - 1)) == C - 1 and C is a power of 2 + if (Pred == ICmpInst::ICMP_ULT && C->isPowerOf2() && + (*C2 & (*C - 1)) == (*C - 1)) + return new ICmpInst(ICmpInst::ICMP_EQ, Builder->CreateOr(Y, *C - 1), X); - if (Shr->hasOneUse()) { - // Otherwise strength reduce the shift into an and. - APInt Val(APInt::getHighBitsSet(TypeBits, TypeBits - ShAmtVal)); - Constant *Mask = Builder->getInt(Val); + // C2 - Y >u C -> (Y | C) != C2 + // iff C2 & C == C and C + 1 is a power of 2 + if (Pred == ICmpInst::ICMP_UGT && (*C + 1).isPowerOf2() && (*C2 & *C) == *C) + return new ICmpInst(ICmpInst::ICMP_NE, Builder->CreateOr(Y, *C), X); - Value *And = Builder->CreateAnd(Shr->getOperand(0), - Mask, Shr->getName()+".mask"); - return new ICmpInst(ICI.getPredicate(), And, ShiftedCmpRHS); - } return nullptr; } -/// Handle "(icmp eq/ne (ashr/lshr const2, A), const1)" -> -/// (icmp eq/ne A, Log2(const2/const1)) -> -/// (icmp eq/ne A, Log2(const2) - Log2(const1)). -Instruction *InstCombiner::FoldICmpCstShrCst(ICmpInst &I, Value *Op, Value *A, - ConstantInt *CI1, - ConstantInt *CI2) { - assert(I.isEquality() && "Cannot fold icmp gt/lt"); - - auto getConstant = [&I, this](bool IsTrue) { - if (I.getPredicate() == I.ICMP_NE) - IsTrue = !IsTrue; - return replaceInstUsesWith(I, ConstantInt::get(I.getType(), IsTrue)); - }; - - auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) { - if (I.getPredicate() == I.ICMP_NE) - Pred = CmpInst::getInversePredicate(Pred); - return new ICmpInst(Pred, LHS, RHS); - }; +/// Fold icmp (add X, Y), C. +Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp, + BinaryOperator *Add, + const APInt *C) { + Value *Y = Add->getOperand(1); + const APInt *C2; + if (Cmp.isEquality() || !match(Y, m_APInt(C2))) + return nullptr; - const APInt &AP1 = CI1->getValue(); - const APInt &AP2 = CI2->getValue(); + // Fold icmp pred (add X, C2), C. + Value *X = Add->getOperand(0); + Type *Ty = Add->getType(); + auto CR = + ConstantRange::makeExactICmpRegion(Cmp.getPredicate(), *C).subtract(*C2); + const APInt &Upper = CR.getUpper(); + const APInt &Lower = CR.getLower(); + if (Cmp.isSigned()) { + if (Lower.isSignBit()) + return new ICmpInst(ICmpInst::ICMP_SLT, X, ConstantInt::get(Ty, Upper)); + if (Upper.isSignBit()) + return new ICmpInst(ICmpInst::ICMP_SGE, X, ConstantInt::get(Ty, Lower)); + } else { + if (Lower.isMinValue()) + return new ICmpInst(ICmpInst::ICMP_ULT, X, ConstantInt::get(Ty, Upper)); + if (Upper.isMinValue()) + return new ICmpInst(ICmpInst::ICMP_UGE, X, ConstantInt::get(Ty, Lower)); + } - // Don't bother doing any work for cases which InstSimplify handles. - if (AP2 == 0) + if (!Add->hasOneUse()) return nullptr; - bool IsAShr = isa<AShrOperator>(Op); - if (IsAShr) { - if (AP2.isAllOnesValue()) - return nullptr; - if (AP2.isNegative() != AP1.isNegative()) - return nullptr; - if (AP2.sgt(AP1)) - return nullptr; - } - if (!AP1) - // 'A' must be large enough to shift out the highest set bit. - return getICmp(I.ICMP_UGT, A, - ConstantInt::get(A->getType(), AP2.logBase2())); + // X+C <u C2 -> (X & -C2) == C + // iff C & (C2-1) == 0 + // C2 is a power of 2 + if (Cmp.getPredicate() == ICmpInst::ICMP_ULT && C->isPowerOf2() && + (*C2 & (*C - 1)) == 0) + return new ICmpInst(ICmpInst::ICMP_EQ, Builder->CreateAnd(X, -(*C)), + ConstantExpr::getNeg(cast<Constant>(Y))); - if (AP1 == AP2) - return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType())); + // X+C >u C2 -> (X & ~C2) != C + // iff C & C2 == 0 + // C2+1 is a power of 2 + if (Cmp.getPredicate() == ICmpInst::ICMP_UGT && (*C + 1).isPowerOf2() && + (*C2 & *C) == 0) + return new ICmpInst(ICmpInst::ICMP_NE, Builder->CreateAnd(X, ~(*C)), + ConstantExpr::getNeg(cast<Constant>(Y))); - int Shift; - if (IsAShr && AP1.isNegative()) - Shift = AP1.countLeadingOnes() - AP2.countLeadingOnes(); - else - Shift = AP1.countLeadingZeros() - AP2.countLeadingZeros(); + return nullptr; +} - if (Shift > 0) { - if (IsAShr && AP1 == AP2.ashr(Shift)) { - // There are multiple solutions if we are comparing against -1 and the LHS - // of the ashr is not a power of two. - if (AP1.isAllOnesValue() && !AP2.isPowerOf2()) - return getICmp(I.ICMP_UGE, A, ConstantInt::get(A->getType(), Shift)); - return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift)); - } else if (AP1 == AP2.lshr(Shift)) { - return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift)); +/// Try to fold integer comparisons with a constant operand: icmp Pred X, C +/// where X is some kind of instruction. +Instruction *InstCombiner::foldICmpInstWithConstant(ICmpInst &Cmp) { + const APInt *C; + if (!match(Cmp.getOperand(1), m_APInt(C))) + return nullptr; + + BinaryOperator *BO; + if (match(Cmp.getOperand(0), m_BinOp(BO))) { + switch (BO->getOpcode()) { + case Instruction::Xor: + if (Instruction *I = foldICmpXorConstant(Cmp, BO, C)) + return I; + break; + case Instruction::And: + if (Instruction *I = foldICmpAndConstant(Cmp, BO, C)) + return I; + break; + case Instruction::Or: + if (Instruction *I = foldICmpOrConstant(Cmp, BO, C)) + return I; + break; + case Instruction::Mul: + if (Instruction *I = foldICmpMulConstant(Cmp, BO, C)) + return I; + break; + case Instruction::Shl: + if (Instruction *I = foldICmpShlConstant(Cmp, BO, C)) + return I; + break; + case Instruction::LShr: + case Instruction::AShr: + if (Instruction *I = foldICmpShrConstant(Cmp, BO, C)) + return I; + break; + case Instruction::UDiv: + if (Instruction *I = foldICmpUDivConstant(Cmp, BO, C)) + return I; + LLVM_FALLTHROUGH; + case Instruction::SDiv: + if (Instruction *I = foldICmpDivConstant(Cmp, BO, C)) + return I; + break; + case Instruction::Sub: + if (Instruction *I = foldICmpSubConstant(Cmp, BO, C)) + return I; + break; + case Instruction::Add: + if (Instruction *I = foldICmpAddConstant(Cmp, BO, C)) + return I; + break; + default: + break; } + // TODO: These folds could be refactored to be part of the above calls. + if (Instruction *I = foldICmpBinOpEqualityWithConstant(Cmp, BO, C)) + return I; } - // Shifting const2 will never be equal to const1. - return getConstant(false); -} - -/// Handle "(icmp eq/ne (shl const2, A), const1)" -> -/// (icmp eq/ne A, TrailingZeros(const1) - TrailingZeros(const2)). -Instruction *InstCombiner::FoldICmpCstShlCst(ICmpInst &I, Value *Op, Value *A, - ConstantInt *CI1, - ConstantInt *CI2) { - assert(I.isEquality() && "Cannot fold icmp gt/lt"); - auto getConstant = [&I, this](bool IsTrue) { - if (I.getPredicate() == I.ICMP_NE) - IsTrue = !IsTrue; - return replaceInstUsesWith(I, ConstantInt::get(I.getType(), IsTrue)); - }; + Instruction *LHSI; + if (match(Cmp.getOperand(0), m_Instruction(LHSI)) && + LHSI->getOpcode() == Instruction::Trunc) + if (Instruction *I = foldICmpTruncConstant(Cmp, LHSI, C)) + return I; - auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) { - if (I.getPredicate() == I.ICMP_NE) - Pred = CmpInst::getInversePredicate(Pred); - return new ICmpInst(Pred, LHS, RHS); - }; + if (Instruction *I = foldICmpIntrinsicWithConstant(Cmp, C)) + return I; - const APInt &AP1 = CI1->getValue(); - const APInt &AP2 = CI2->getValue(); + return nullptr; +} - // Don't bother doing any work for cases which InstSimplify handles. - if (AP2 == 0) +/// Fold an icmp equality instruction with binary operator LHS and constant RHS: +/// icmp eq/ne BO, C. +Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, + BinaryOperator *BO, + const APInt *C) { + // TODO: Some of these folds could work with arbitrary constants, but this + // function is limited to scalar and vector splat constants. + if (!Cmp.isEquality()) return nullptr; - unsigned AP2TrailingZeros = AP2.countTrailingZeros(); + ICmpInst::Predicate Pred = Cmp.getPredicate(); + bool isICMP_NE = Pred == ICmpInst::ICMP_NE; + Constant *RHS = cast<Constant>(Cmp.getOperand(1)); + Value *BOp0 = BO->getOperand(0), *BOp1 = BO->getOperand(1); - if (!AP1 && AP2TrailingZeros != 0) - return getICmp(I.ICMP_UGE, A, - ConstantInt::get(A->getType(), AP2.getBitWidth() - AP2TrailingZeros)); + switch (BO->getOpcode()) { + case Instruction::SRem: + // If we have a signed (X % (2^c)) == 0, turn it into an unsigned one. + if (*C == 0 && BO->hasOneUse()) { + const APInt *BOC; + if (match(BOp1, m_APInt(BOC)) && BOC->sgt(1) && BOC->isPowerOf2()) { + Value *NewRem = Builder->CreateURem(BOp0, BOp1, BO->getName()); + return new ICmpInst(Pred, NewRem, + Constant::getNullValue(BO->getType())); + } + } + break; + case Instruction::Add: { + // Replace ((add A, B) != C) with (A != C-B) if B & C are constants. + const APInt *BOC; + if (match(BOp1, m_APInt(BOC))) { + if (BO->hasOneUse()) { + Constant *SubC = ConstantExpr::getSub(RHS, cast<Constant>(BOp1)); + return new ICmpInst(Pred, BOp0, SubC); + } + } else if (*C == 0) { + // Replace ((add A, B) != 0) with (A != -B) if A or B is + // efficiently invertible, or if the add has just this one use. + if (Value *NegVal = dyn_castNegVal(BOp1)) + return new ICmpInst(Pred, BOp0, NegVal); + if (Value *NegVal = dyn_castNegVal(BOp0)) + return new ICmpInst(Pred, NegVal, BOp1); + if (BO->hasOneUse()) { + Value *Neg = Builder->CreateNeg(BOp1); + Neg->takeName(BO); + return new ICmpInst(Pred, BOp0, Neg); + } + } + break; + } + case Instruction::Xor: + if (BO->hasOneUse()) { + if (Constant *BOC = dyn_cast<Constant>(BOp1)) { + // For the xor case, we can xor two constants together, eliminating + // the explicit xor. + return new ICmpInst(Pred, BOp0, ConstantExpr::getXor(RHS, BOC)); + } else if (*C == 0) { + // Replace ((xor A, B) != 0) with (A != B) + return new ICmpInst(Pred, BOp0, BOp1); + } + } + break; + case Instruction::Sub: + if (BO->hasOneUse()) { + const APInt *BOC; + if (match(BOp0, m_APInt(BOC))) { + // Replace ((sub BOC, B) != C) with (B != BOC-C). + Constant *SubC = ConstantExpr::getSub(cast<Constant>(BOp0), RHS); + return new ICmpInst(Pred, BOp1, SubC); + } else if (*C == 0) { + // Replace ((sub A, B) != 0) with (A != B). + return new ICmpInst(Pred, BOp0, BOp1); + } + } + break; + case Instruction::Or: { + const APInt *BOC; + if (match(BOp1, m_APInt(BOC)) && BO->hasOneUse() && RHS->isAllOnesValue()) { + // Comparing if all bits outside of a constant mask are set? + // Replace (X | C) == -1 with (X & ~C) == ~C. + // This removes the -1 constant. + Constant *NotBOC = ConstantExpr::getNot(cast<Constant>(BOp1)); + Value *And = Builder->CreateAnd(BOp0, NotBOC); + return new ICmpInst(Pred, And, NotBOC); + } + break; + } + case Instruction::And: { + const APInt *BOC; + if (match(BOp1, m_APInt(BOC))) { + // If we have ((X & C) == C), turn it into ((X & C) != 0). + if (C == BOC && C->isPowerOf2()) + return new ICmpInst(isICMP_NE ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, + BO, Constant::getNullValue(RHS->getType())); - if (AP1 == AP2) - return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType())); + // Don't perform the following transforms if the AND has multiple uses + if (!BO->hasOneUse()) + break; - // Get the distance between the lowest bits that are set. - int Shift = AP1.countTrailingZeros() - AP2TrailingZeros; + // Replace (and X, (1 << size(X)-1) != 0) with x s< 0 + if (BOC->isSignBit()) { + Constant *Zero = Constant::getNullValue(BOp0->getType()); + auto NewPred = isICMP_NE ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_SGE; + return new ICmpInst(NewPred, BOp0, Zero); + } - if (Shift > 0 && AP2.shl(Shift) == AP1) - return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift)); + // ((X & ~7) == 0) --> X < 8 + if (*C == 0 && (~(*BOC) + 1).isPowerOf2()) { + Constant *NegBOC = ConstantExpr::getNeg(cast<Constant>(BOp1)); + auto NewPred = isICMP_NE ? ICmpInst::ICMP_UGE : ICmpInst::ICMP_ULT; + return new ICmpInst(NewPred, BOp0, NegBOC); + } + } + break; + } + case Instruction::Mul: + if (*C == 0 && BO->hasNoSignedWrap()) { + const APInt *BOC; + if (match(BOp1, m_APInt(BOC)) && *BOC != 0) { + // The trivial case (mul X, 0) is handled by InstSimplify. + // General case : (mul X, C) != 0 iff X != 0 + // (mul X, C) == 0 iff X == 0 + return new ICmpInst(Pred, BOp0, Constant::getNullValue(RHS->getType())); + } + } + break; + case Instruction::UDiv: + if (*C == 0) { + // (icmp eq/ne (udiv A, B), 0) -> (icmp ugt/ule i32 B, A) + auto NewPred = isICMP_NE ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_UGT; + return new ICmpInst(NewPred, BOp1, BOp0); + } + break; + default: + break; + } + return nullptr; +} - // Shifting const2 will never be equal to const1. - return getConstant(false); +/// Fold an icmp with LLVM intrinsic and constant operand: icmp Pred II, C. +Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &Cmp, + const APInt *C) { + IntrinsicInst *II = dyn_cast<IntrinsicInst>(Cmp.getOperand(0)); + if (!II || !Cmp.isEquality()) + return nullptr; + + // Handle icmp {eq|ne} <intrinsic>, intcst. + switch (II->getIntrinsicID()) { + case Intrinsic::bswap: + Worklist.Add(II); + Cmp.setOperand(0, II->getArgOperand(0)); + Cmp.setOperand(1, Builder->getInt(C->byteSwap())); + return &Cmp; + case Intrinsic::ctlz: + case Intrinsic::cttz: + // ctz(A) == bitwidth(A) -> A == 0 and likewise for != + if (*C == C->getBitWidth()) { + Worklist.Add(II); + Cmp.setOperand(0, II->getArgOperand(0)); + Cmp.setOperand(1, ConstantInt::getNullValue(II->getType())); + return &Cmp; + } + break; + case Intrinsic::ctpop: { + // popcount(A) == 0 -> A == 0 and likewise for != + // popcount(A) == bitwidth(A) -> A == -1 and likewise for != + bool IsZero = *C == 0; + if (IsZero || *C == C->getBitWidth()) { + Worklist.Add(II); + Cmp.setOperand(0, II->getArgOperand(0)); + auto *NewOp = IsZero ? Constant::getNullValue(II->getType()) + : Constant::getAllOnesValue(II->getType()); + Cmp.setOperand(1, NewOp); + return &Cmp; + } + break; + } + default: + break; + } + return nullptr; } -/// Handle "icmp (instr, intcst)". -Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, - Instruction *LHSI, - ConstantInt *RHS) { - const APInt &RHSV = RHS->getValue(); +/// Handle icmp with constant (but not simple integer constant) RHS. +Instruction *InstCombiner::foldICmpInstWithConstantNotInt(ICmpInst &I) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Constant *RHSC = dyn_cast<Constant>(Op1); + Instruction *LHSI = dyn_cast<Instruction>(Op0); + if (!RHSC || !LHSI) + return nullptr; switch (LHSI->getOpcode()) { - case Instruction::Trunc: - if (RHS->isOne() && RHSV.getBitWidth() > 1) { - // icmp slt trunc(signum(V)) 1 --> icmp slt V, 1 - Value *V = nullptr; - if (ICI.getPredicate() == ICmpInst::ICMP_SLT && - match(LHSI->getOperand(0), m_Signum(m_Value(V)))) - return new ICmpInst(ICmpInst::ICMP_SLT, V, - ConstantInt::get(V->getType(), 1)); + case Instruction::GetElementPtr: + // icmp pred GEP (P, int 0, int 0, int 0), null -> icmp pred P, null + if (RHSC->isNullValue() && + cast<GetElementPtrInst>(LHSI)->hasAllZeroIndices()) + return new ICmpInst( + I.getPredicate(), LHSI->getOperand(0), + Constant::getNullValue(LHSI->getOperand(0)->getType())); + break; + case Instruction::PHI: + // Only fold icmp into the PHI if the phi and icmp are in the same + // block. If in the same block, we're encouraging jump threading. If + // not, we are just pessimizing the code by making an i1 phi. + if (LHSI->getParent() == I.getParent()) + if (Instruction *NV = FoldOpIntoPhi(I)) + return NV; + break; + case Instruction::Select: { + // If either operand of the select is a constant, we can fold the + // comparison into the select arms, which will cause one to be + // constant folded and the select turned into a bitwise or. + Value *Op1 = nullptr, *Op2 = nullptr; + ConstantInt *CI = nullptr; + if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(1))) { + Op1 = ConstantExpr::getICmp(I.getPredicate(), C, RHSC); + CI = dyn_cast<ConstantInt>(Op1); + } + if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(2))) { + Op2 = ConstantExpr::getICmp(I.getPredicate(), C, RHSC); + CI = dyn_cast<ConstantInt>(Op2); } - if (ICI.isEquality() && LHSI->hasOneUse()) { - // Simplify icmp eq (trunc x to i8), 42 -> icmp eq x, 42|highbits if all - // of the high bits truncated out of x are known. - unsigned DstBits = LHSI->getType()->getPrimitiveSizeInBits(), - SrcBits = LHSI->getOperand(0)->getType()->getPrimitiveSizeInBits(); - APInt KnownZero(SrcBits, 0), KnownOne(SrcBits, 0); - computeKnownBits(LHSI->getOperand(0), KnownZero, KnownOne, 0, &ICI); - // If all the high bits are known, we can do this xform. - if ((KnownZero|KnownOne).countLeadingOnes() >= SrcBits-DstBits) { - // Pull in the high bits from known-ones set. - APInt NewRHS = RHS->getValue().zext(SrcBits); - NewRHS |= KnownOne & APInt::getHighBitsSet(SrcBits, SrcBits-DstBits); - return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0), - Builder->getInt(NewRHS)); - } + // We only want to perform this transformation if it will not lead to + // additional code. This is true if either both sides of the select + // fold to a constant (in which case the icmp is replaced with a select + // which will usually simplify) or this is the only user of the + // select (in which case we are trading a select+icmp for a simpler + // select+icmp) or all uses of the select can be replaced based on + // dominance information ("Global cases"). + bool Transform = false; + if (Op1 && Op2) + Transform = true; + else if (Op1 || Op2) { + // Local case + if (LHSI->hasOneUse()) + Transform = true; + // Global cases + else if (CI && !CI->isZero()) + // When Op1 is constant try replacing select with second operand. + // Otherwise Op2 is constant and try replacing select with first + // operand. + Transform = + replacedSelectWithOperand(cast<SelectInst>(LHSI), &I, Op1 ? 2 : 1); } + if (Transform) { + if (!Op1) + Op1 = Builder->CreateICmp(I.getPredicate(), LHSI->getOperand(1), RHSC, + I.getName()); + if (!Op2) + Op2 = Builder->CreateICmp(I.getPredicate(), LHSI->getOperand(2), RHSC, + I.getName()); + return SelectInst::Create(LHSI->getOperand(0), Op1, Op2); + } + break; + } + case Instruction::IntToPtr: + // icmp pred inttoptr(X), null -> icmp pred X, 0 + if (RHSC->isNullValue() && + DL.getIntPtrType(RHSC->getType()) == LHSI->getOperand(0)->getType()) + return new ICmpInst( + I.getPredicate(), LHSI->getOperand(0), + Constant::getNullValue(LHSI->getOperand(0)->getType())); break; - case Instruction::Xor: // (icmp pred (xor X, XorCst), CI) - if (ConstantInt *XorCst = dyn_cast<ConstantInt>(LHSI->getOperand(1))) { - // If this is a comparison that tests the signbit (X < 0) or (x > -1), - // fold the xor. - if ((ICI.getPredicate() == ICmpInst::ICMP_SLT && RHSV == 0) || - (ICI.getPredicate() == ICmpInst::ICMP_SGT && RHSV.isAllOnesValue())) { - Value *CompareVal = LHSI->getOperand(0); + case Instruction::Load: + // Try to optimize things like "A[i] > 4" to index computations. + if (GetElementPtrInst *GEP = + dyn_cast<GetElementPtrInst>(LHSI->getOperand(0))) { + if (GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0))) + if (GV->isConstant() && GV->hasDefinitiveInitializer() && + !cast<LoadInst>(LHSI)->isVolatile()) + if (Instruction *Res = foldCmpLoadFromIndexedGlobal(GEP, GV, I)) + return Res; + } + break; + } - // If the sign bit of the XorCst is not set, there is no change to - // the operation, just stop using the Xor. - if (!XorCst->isNegative()) { - ICI.setOperand(0, CompareVal); - Worklist.Add(LHSI); - return &ICI; - } + return nullptr; +} - // Was the old condition true if the operand is positive? - bool isTrueIfPositive = ICI.getPredicate() == ICmpInst::ICMP_SGT; +/// Try to fold icmp (binop), X or icmp X, (binop). +Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - // If so, the new one isn't. - isTrueIfPositive ^= true; + // Special logic for binary operators. + BinaryOperator *BO0 = dyn_cast<BinaryOperator>(Op0); + BinaryOperator *BO1 = dyn_cast<BinaryOperator>(Op1); + if (!BO0 && !BO1) + return nullptr; - if (isTrueIfPositive) - return new ICmpInst(ICmpInst::ICMP_SGT, CompareVal, - SubOne(RHS)); - else - return new ICmpInst(ICmpInst::ICMP_SLT, CompareVal, - AddOne(RHS)); - } + CmpInst::Predicate Pred = I.getPredicate(); + bool NoOp0WrapProblem = false, NoOp1WrapProblem = false; + if (BO0 && isa<OverflowingBinaryOperator>(BO0)) + NoOp0WrapProblem = + ICmpInst::isEquality(Pred) || + (CmpInst::isUnsigned(Pred) && BO0->hasNoUnsignedWrap()) || + (CmpInst::isSigned(Pred) && BO0->hasNoSignedWrap()); + if (BO1 && isa<OverflowingBinaryOperator>(BO1)) + NoOp1WrapProblem = + ICmpInst::isEquality(Pred) || + (CmpInst::isUnsigned(Pred) && BO1->hasNoUnsignedWrap()) || + (CmpInst::isSigned(Pred) && BO1->hasNoSignedWrap()); - if (LHSI->hasOneUse()) { - // (icmp u/s (xor A SignBit), C) -> (icmp s/u A, (xor C SignBit)) - if (!ICI.isEquality() && XorCst->getValue().isSignBit()) { - const APInt &SignBit = XorCst->getValue(); - ICmpInst::Predicate Pred = ICI.isSigned() - ? ICI.getUnsignedPredicate() - : ICI.getSignedPredicate(); - return new ICmpInst(Pred, LHSI->getOperand(0), - Builder->getInt(RHSV ^ SignBit)); - } + // Analyze the case when either Op0 or Op1 is an add instruction. + // Op0 = A + B (or A and B are null); Op1 = C + D (or C and D are null). + Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr; + if (BO0 && BO0->getOpcode() == Instruction::Add) { + A = BO0->getOperand(0); + B = BO0->getOperand(1); + } + if (BO1 && BO1->getOpcode() == Instruction::Add) { + C = BO1->getOperand(0); + D = BO1->getOperand(1); + } - // (icmp u/s (xor A ~SignBit), C) -> (icmp s/u (xor C ~SignBit), A) - if (!ICI.isEquality() && XorCst->isMaxValue(true)) { - const APInt &NotSignBit = XorCst->getValue(); - ICmpInst::Predicate Pred = ICI.isSigned() - ? ICI.getUnsignedPredicate() - : ICI.getSignedPredicate(); - Pred = ICI.getSwappedPredicate(Pred); - return new ICmpInst(Pred, LHSI->getOperand(0), - Builder->getInt(RHSV ^ NotSignBit)); - } - } + // icmp (X+cst) < 0 --> X < -cst + if (NoOp0WrapProblem && ICmpInst::isSigned(Pred) && match(Op1, m_Zero())) + if (ConstantInt *RHSC = dyn_cast_or_null<ConstantInt>(B)) + if (!RHSC->isMinValue(/*isSigned=*/true)) + return new ICmpInst(Pred, A, ConstantExpr::getNeg(RHSC)); + + // icmp (X+Y), X -> icmp Y, 0 for equalities or if there is no overflow. + if ((A == Op1 || B == Op1) && NoOp0WrapProblem) + return new ICmpInst(Pred, A == Op1 ? B : A, + Constant::getNullValue(Op1->getType())); - // (icmp ugt (xor X, C), ~C) -> (icmp ult X, C) - // iff -C is a power of 2 - if (ICI.getPredicate() == ICmpInst::ICMP_UGT && - XorCst->getValue() == ~RHSV && (RHSV + 1).isPowerOf2()) - return new ICmpInst(ICmpInst::ICMP_ULT, LHSI->getOperand(0), XorCst); + // icmp X, (X+Y) -> icmp 0, Y for equalities or if there is no overflow. + if ((C == Op0 || D == Op0) && NoOp1WrapProblem) + return new ICmpInst(Pred, Constant::getNullValue(Op0->getType()), + C == Op0 ? D : C); - // (icmp ult (xor X, C), -C) -> (icmp uge X, C) - // iff -C is a power of 2 - if (ICI.getPredicate() == ICmpInst::ICMP_ULT && - XorCst->getValue() == -RHSV && RHSV.isPowerOf2()) - return new ICmpInst(ICmpInst::ICMP_UGE, LHSI->getOperand(0), XorCst); + // icmp (X+Y), (X+Z) -> icmp Y, Z for equalities or if there is no overflow. + if (A && C && (A == C || A == D || B == C || B == D) && NoOp0WrapProblem && + NoOp1WrapProblem && + // Try not to increase register pressure. + BO0->hasOneUse() && BO1->hasOneUse()) { + // Determine Y and Z in the form icmp (X+Y), (X+Z). + Value *Y, *Z; + if (A == C) { + // C + B == C + D -> B == D + Y = B; + Z = D; + } else if (A == D) { + // D + B == C + D -> B == C + Y = B; + Z = C; + } else if (B == C) { + // A + C == C + D -> A == D + Y = A; + Z = D; + } else { + assert(B == D); + // A + D == C + D -> A == C + Y = A; + Z = C; } - break; - case Instruction::And: // (icmp pred (and X, AndCst), RHS) - if (LHSI->hasOneUse() && isa<ConstantInt>(LHSI->getOperand(1)) && - LHSI->getOperand(0)->hasOneUse()) { - ConstantInt *AndCst = cast<ConstantInt>(LHSI->getOperand(1)); + return new ICmpInst(Pred, Y, Z); + } - // If the LHS is an AND of a truncating cast, we can widen the - // and/compare to be the input width without changing the value - // produced, eliminating a cast. - if (TruncInst *Cast = dyn_cast<TruncInst>(LHSI->getOperand(0))) { - // We can do this transformation if either the AND constant does not - // have its sign bit set or if it is an equality comparison. - // Extending a relational comparison when we're checking the sign - // bit would not work. - if (ICI.isEquality() || - (!AndCst->isNegative() && RHSV.isNonNegative())) { - Value *NewAnd = - Builder->CreateAnd(Cast->getOperand(0), - ConstantExpr::getZExt(AndCst, Cast->getSrcTy())); - NewAnd->takeName(LHSI); - return new ICmpInst(ICI.getPredicate(), NewAnd, - ConstantExpr::getZExt(RHS, Cast->getSrcTy())); - } - } + // icmp slt (X + -1), Y -> icmp sle X, Y + if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SLT && + match(B, m_AllOnes())) + return new ICmpInst(CmpInst::ICMP_SLE, A, Op1); - // If the LHS is an AND of a zext, and we have an equality compare, we can - // shrink the and/compare to the smaller type, eliminating the cast. - if (ZExtInst *Cast = dyn_cast<ZExtInst>(LHSI->getOperand(0))) { - IntegerType *Ty = cast<IntegerType>(Cast->getSrcTy()); - // Make sure we don't compare the upper bits, SimplifyDemandedBits - // should fold the icmp to true/false in that case. - if (ICI.isEquality() && RHSV.getActiveBits() <= Ty->getBitWidth()) { - Value *NewAnd = - Builder->CreateAnd(Cast->getOperand(0), - ConstantExpr::getTrunc(AndCst, Ty)); - NewAnd->takeName(LHSI); - return new ICmpInst(ICI.getPredicate(), NewAnd, - ConstantExpr::getTrunc(RHS, Ty)); - } - } + // icmp sge (X + -1), Y -> icmp sgt X, Y + if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SGE && + match(B, m_AllOnes())) + return new ICmpInst(CmpInst::ICMP_SGT, A, Op1); - // If this is: (X >> C1) & C2 != C3 (where any shift and any compare - // could exist), turn it into (X & (C2 << C1)) != (C3 << C1). This - // happens a LOT in code produced by the C front-end, for bitfield - // access. - BinaryOperator *Shift = dyn_cast<BinaryOperator>(LHSI->getOperand(0)); - if (Shift && !Shift->isShift()) - Shift = nullptr; + // icmp sle (X + 1), Y -> icmp slt X, Y + if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SLE && match(B, m_One())) + return new ICmpInst(CmpInst::ICMP_SLT, A, Op1); - ConstantInt *ShAmt; - ShAmt = Shift ? dyn_cast<ConstantInt>(Shift->getOperand(1)) : nullptr; + // icmp sgt (X + 1), Y -> icmp sge X, Y + if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SGT && match(B, m_One())) + return new ICmpInst(CmpInst::ICMP_SGE, A, Op1); - // This seemingly simple opportunity to fold away a shift turns out to - // be rather complicated. See PR17827 - // ( http://llvm.org/bugs/show_bug.cgi?id=17827 ) for details. - if (ShAmt) { - bool CanFold = false; - unsigned ShiftOpcode = Shift->getOpcode(); - if (ShiftOpcode == Instruction::AShr) { - // There may be some constraints that make this possible, - // but nothing simple has been discovered yet. - CanFold = false; - } else if (ShiftOpcode == Instruction::Shl) { - // For a left shift, we can fold if the comparison is not signed. - // We can also fold a signed comparison if the mask value and - // comparison value are not negative. These constraints may not be - // obvious, but we can prove that they are correct using an SMT - // solver. - if (!ICI.isSigned() || (!AndCst->isNegative() && !RHS->isNegative())) - CanFold = true; - } else if (ShiftOpcode == Instruction::LShr) { - // For a logical right shift, we can fold if the comparison is not - // signed. We can also fold a signed comparison if the shifted mask - // value and the shifted comparison value are not negative. - // These constraints may not be obvious, but we can prove that they - // are correct using an SMT solver. - if (!ICI.isSigned()) - CanFold = true; - else { - ConstantInt *ShiftedAndCst = - cast<ConstantInt>(ConstantExpr::getShl(AndCst, ShAmt)); - ConstantInt *ShiftedRHSCst = - cast<ConstantInt>(ConstantExpr::getShl(RHS, ShAmt)); - - if (!ShiftedAndCst->isNegative() && !ShiftedRHSCst->isNegative()) - CanFold = true; - } - } + // icmp sgt X, (Y + -1) -> icmp sge X, Y + if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SGT && + match(D, m_AllOnes())) + return new ICmpInst(CmpInst::ICMP_SGE, Op0, C); - if (CanFold) { - Constant *NewCst; - if (ShiftOpcode == Instruction::Shl) - NewCst = ConstantExpr::getLShr(RHS, ShAmt); - else - NewCst = ConstantExpr::getShl(RHS, ShAmt); + // icmp sle X, (Y + -1) -> icmp slt X, Y + if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SLE && + match(D, m_AllOnes())) + return new ICmpInst(CmpInst::ICMP_SLT, Op0, C); - // Check to see if we are shifting out any of the bits being - // compared. - if (ConstantExpr::get(ShiftOpcode, NewCst, ShAmt) != RHS) { - // If we shifted bits out, the fold is not going to work out. - // As a special case, check to see if this means that the - // result is always true or false now. - if (ICI.getPredicate() == ICmpInst::ICMP_EQ) - return replaceInstUsesWith(ICI, Builder->getFalse()); - if (ICI.getPredicate() == ICmpInst::ICMP_NE) - return replaceInstUsesWith(ICI, Builder->getTrue()); + // icmp sge X, (Y + 1) -> icmp sgt X, Y + if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SGE && match(D, m_One())) + return new ICmpInst(CmpInst::ICMP_SGT, Op0, C); + + // icmp slt X, (Y + 1) -> icmp sle X, Y + if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SLT && match(D, m_One())) + return new ICmpInst(CmpInst::ICMP_SLE, Op0, C); + + // if C1 has greater magnitude than C2: + // icmp (X + C1), (Y + C2) -> icmp (X + C3), Y + // s.t. C3 = C1 - C2 + // + // if C2 has greater magnitude than C1: + // icmp (X + C1), (Y + C2) -> icmp X, (Y + C3) + // s.t. C3 = C2 - C1 + if (A && C && NoOp0WrapProblem && NoOp1WrapProblem && + (BO0->hasOneUse() || BO1->hasOneUse()) && !I.isUnsigned()) + if (ConstantInt *C1 = dyn_cast<ConstantInt>(B)) + if (ConstantInt *C2 = dyn_cast<ConstantInt>(D)) { + const APInt &AP1 = C1->getValue(); + const APInt &AP2 = C2->getValue(); + if (AP1.isNegative() == AP2.isNegative()) { + APInt AP1Abs = C1->getValue().abs(); + APInt AP2Abs = C2->getValue().abs(); + if (AP1Abs.uge(AP2Abs)) { + ConstantInt *C3 = Builder->getInt(AP1 - AP2); + Value *NewAdd = Builder->CreateNSWAdd(A, C3); + return new ICmpInst(Pred, NewAdd, C); } else { - ICI.setOperand(1, NewCst); - Constant *NewAndCst; - if (ShiftOpcode == Instruction::Shl) - NewAndCst = ConstantExpr::getLShr(AndCst, ShAmt); - else - NewAndCst = ConstantExpr::getShl(AndCst, ShAmt); - LHSI->setOperand(1, NewAndCst); - LHSI->setOperand(0, Shift->getOperand(0)); - Worklist.Add(Shift); // Shift is dead. - return &ICI; + ConstantInt *C3 = Builder->getInt(AP2 - AP1); + Value *NewAdd = Builder->CreateNSWAdd(C, C3); + return new ICmpInst(Pred, A, NewAdd); } } } - // Turn ((X >> Y) & C) == 0 into (X & (C << Y)) == 0. The later is - // preferable because it allows the C<<Y expression to be hoisted out - // of a loop if Y is invariant and X is not. - if (Shift && Shift->hasOneUse() && RHSV == 0 && - ICI.isEquality() && !Shift->isArithmeticShift() && - !isa<Constant>(Shift->getOperand(0))) { - // Compute C << Y. - Value *NS; - if (Shift->getOpcode() == Instruction::LShr) { - NS = Builder->CreateShl(AndCst, Shift->getOperand(1)); - } else { - // Insert a logical shift. - NS = Builder->CreateLShr(AndCst, Shift->getOperand(1)); - } + // Analyze the case when either Op0 or Op1 is a sub instruction. + // Op0 = A - B (or A and B are null); Op1 = C - D (or C and D are null). + A = nullptr; + B = nullptr; + C = nullptr; + D = nullptr; + if (BO0 && BO0->getOpcode() == Instruction::Sub) { + A = BO0->getOperand(0); + B = BO0->getOperand(1); + } + if (BO1 && BO1->getOpcode() == Instruction::Sub) { + C = BO1->getOperand(0); + D = BO1->getOperand(1); + } - // Compute X & (C << Y). - Value *NewAnd = - Builder->CreateAnd(Shift->getOperand(0), NS, LHSI->getName()); + // icmp (X-Y), X -> icmp 0, Y for equalities or if there is no overflow. + if (A == Op1 && NoOp0WrapProblem) + return new ICmpInst(Pred, Constant::getNullValue(Op1->getType()), B); - ICI.setOperand(0, NewAnd); - return &ICI; - } + // icmp X, (X-Y) -> icmp Y, 0 for equalities or if there is no overflow. + if (C == Op0 && NoOp1WrapProblem) + return new ICmpInst(Pred, D, Constant::getNullValue(Op0->getType())); - // (icmp pred (and (or (lshr X, Y), X), 1), 0) --> - // (icmp pred (and X, (or (shl 1, Y), 1), 0)) - // - // iff pred isn't signed - { - Value *X, *Y, *LShr; - if (!ICI.isSigned() && RHSV == 0) { - if (match(LHSI->getOperand(1), m_One())) { - Constant *One = cast<Constant>(LHSI->getOperand(1)); - Value *Or = LHSI->getOperand(0); - if (match(Or, m_Or(m_Value(LShr), m_Value(X))) && - match(LShr, m_LShr(m_Specific(X), m_Value(Y)))) { - unsigned UsesRemoved = 0; - if (LHSI->hasOneUse()) - ++UsesRemoved; - if (Or->hasOneUse()) - ++UsesRemoved; - if (LShr->hasOneUse()) - ++UsesRemoved; - Value *NewOr = nullptr; - // Compute X & ((1 << Y) | 1) - if (auto *C = dyn_cast<Constant>(Y)) { - if (UsesRemoved >= 1) - NewOr = - ConstantExpr::getOr(ConstantExpr::getNUWShl(One, C), One); - } else { - if (UsesRemoved >= 3) - NewOr = Builder->CreateOr(Builder->CreateShl(One, Y, - LShr->getName(), - /*HasNUW=*/true), - One, Or->getName()); - } - if (NewOr) { - Value *NewAnd = Builder->CreateAnd(X, NewOr, LHSI->getName()); - ICI.setOperand(0, NewAnd); - return &ICI; - } - } - } - } - } + // icmp (Y-X), (Z-X) -> icmp Y, Z for equalities or if there is no overflow. + if (B && D && B == D && NoOp0WrapProblem && NoOp1WrapProblem && + // Try not to increase register pressure. + BO0->hasOneUse() && BO1->hasOneUse()) + return new ICmpInst(Pred, A, C); - // Replace ((X & AndCst) > RHSV) with ((X & AndCst) != 0), if any - // bit set in (X & AndCst) will produce a result greater than RHSV. - if (ICI.getPredicate() == ICmpInst::ICMP_UGT) { - unsigned NTZ = AndCst->getValue().countTrailingZeros(); - if ((NTZ < AndCst->getBitWidth()) && - APInt::getOneBitSet(AndCst->getBitWidth(), NTZ).ugt(RHSV)) - return new ICmpInst(ICmpInst::ICMP_NE, LHSI, - Constant::getNullValue(RHS->getType())); - } - } + // icmp (X-Y), (X-Z) -> icmp Z, Y for equalities or if there is no overflow. + if (A && C && A == C && NoOp0WrapProblem && NoOp1WrapProblem && + // Try not to increase register pressure. + BO0->hasOneUse() && BO1->hasOneUse()) + return new ICmpInst(Pred, D, B); - // Try to optimize things like "A[i]&42 == 0" to index computations. - if (LoadInst *LI = dyn_cast<LoadInst>(LHSI->getOperand(0))) { - if (GetElementPtrInst *GEP = - dyn_cast<GetElementPtrInst>(LI->getOperand(0))) - if (GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0))) - if (GV->isConstant() && GV->hasDefinitiveInitializer() && - !LI->isVolatile() && isa<ConstantInt>(LHSI->getOperand(1))) { - ConstantInt *C = cast<ConstantInt>(LHSI->getOperand(1)); - if (Instruction *Res = FoldCmpLoadFromIndexedGlobal(GEP, GV,ICI, C)) - return Res; - } + // icmp (0-X) < cst --> x > -cst + if (NoOp0WrapProblem && ICmpInst::isSigned(Pred)) { + Value *X; + if (match(BO0, m_Neg(m_Value(X)))) + if (ConstantInt *RHSC = dyn_cast<ConstantInt>(Op1)) + if (!RHSC->isMinValue(/*isSigned=*/true)) + return new ICmpInst(I.getSwappedPredicate(), X, + ConstantExpr::getNeg(RHSC)); + } + + BinaryOperator *SRem = nullptr; + // icmp (srem X, Y), Y + if (BO0 && BO0->getOpcode() == Instruction::SRem && Op1 == BO0->getOperand(1)) + SRem = BO0; + // icmp Y, (srem X, Y) + else if (BO1 && BO1->getOpcode() == Instruction::SRem && + Op0 == BO1->getOperand(1)) + SRem = BO1; + if (SRem) { + // We don't check hasOneUse to avoid increasing register pressure because + // the value we use is the same value this instruction was already using. + switch (SRem == BO0 ? ICmpInst::getSwappedPredicate(Pred) : Pred) { + default: + break; + case ICmpInst::ICMP_EQ: + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + case ICmpInst::ICMP_NE: + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + case ICmpInst::ICMP_SGT: + case ICmpInst::ICMP_SGE: + return new ICmpInst(ICmpInst::ICMP_SGT, SRem->getOperand(1), + Constant::getAllOnesValue(SRem->getType())); + case ICmpInst::ICMP_SLT: + case ICmpInst::ICMP_SLE: + return new ICmpInst(ICmpInst::ICMP_SLT, SRem->getOperand(1), + Constant::getNullValue(SRem->getType())); } + } - // X & -C == -C -> X > u ~C - // X & -C != -C -> X <= u ~C - // iff C is a power of 2 - if (ICI.isEquality() && RHS == LHSI->getOperand(1) && (-RHSV).isPowerOf2()) - return new ICmpInst( - ICI.getPredicate() == ICmpInst::ICMP_EQ ? ICmpInst::ICMP_UGT - : ICmpInst::ICMP_ULE, - LHSI->getOperand(0), SubOne(RHS)); + if (BO0 && BO1 && BO0->getOpcode() == BO1->getOpcode() && BO0->hasOneUse() && + BO1->hasOneUse() && BO0->getOperand(1) == BO1->getOperand(1)) { + switch (BO0->getOpcode()) { + default: + break; + case Instruction::Add: + case Instruction::Sub: + case Instruction::Xor: + if (I.isEquality()) // a+x icmp eq/ne b+x --> a icmp b + return new ICmpInst(I.getPredicate(), BO0->getOperand(0), + BO1->getOperand(0)); + // icmp u/s (a ^ signbit), (b ^ signbit) --> icmp s/u a, b + if (ConstantInt *CI = dyn_cast<ConstantInt>(BO0->getOperand(1))) { + if (CI->getValue().isSignBit()) { + ICmpInst::Predicate Pred = + I.isSigned() ? I.getUnsignedPredicate() : I.getSignedPredicate(); + return new ICmpInst(Pred, BO0->getOperand(0), BO1->getOperand(0)); + } - // (icmp eq (and %A, C), 0) -> (icmp sgt (trunc %A), -1) - // iff C is a power of 2 - if (ICI.isEquality() && LHSI->hasOneUse() && match(RHS, m_Zero())) { - if (auto *CI = dyn_cast<ConstantInt>(LHSI->getOperand(1))) { - const APInt &AI = CI->getValue(); - int32_t ExactLogBase2 = AI.exactLogBase2(); - if (ExactLogBase2 != -1 && DL.isLegalInteger(ExactLogBase2 + 1)) { - Type *NTy = IntegerType::get(ICI.getContext(), ExactLogBase2 + 1); - Value *Trunc = Builder->CreateTrunc(LHSI->getOperand(0), NTy); - return new ICmpInst(ICI.getPredicate() == ICmpInst::ICMP_EQ - ? ICmpInst::ICMP_SGE - : ICmpInst::ICMP_SLT, - Trunc, Constant::getNullValue(NTy)); + if (BO0->getOpcode() == Instruction::Xor && CI->isMaxValue(true)) { + ICmpInst::Predicate Pred = + I.isSigned() ? I.getUnsignedPredicate() : I.getSignedPredicate(); + Pred = I.getSwappedPredicate(Pred); + return new ICmpInst(Pred, BO0->getOperand(0), BO1->getOperand(0)); } } - } - break; - - case Instruction::Or: { - if (RHS->isOne()) { - // icmp slt signum(V) 1 --> icmp slt V, 1 - Value *V = nullptr; - if (ICI.getPredicate() == ICmpInst::ICMP_SLT && - match(LHSI, m_Signum(m_Value(V)))) - return new ICmpInst(ICmpInst::ICMP_SLT, V, - ConstantInt::get(V->getType(), 1)); - } + break; + case Instruction::Mul: + if (!I.isEquality()) + break; - if (!ICI.isEquality() || !RHS->isNullValue() || !LHSI->hasOneUse()) + if (ConstantInt *CI = dyn_cast<ConstantInt>(BO0->getOperand(1))) { + // a * Cst icmp eq/ne b * Cst --> a & Mask icmp b & Mask + // Mask = -1 >> count-trailing-zeros(Cst). + if (!CI->isZero() && !CI->isOne()) { + const APInt &AP = CI->getValue(); + ConstantInt *Mask = ConstantInt::get( + I.getContext(), + APInt::getLowBitsSet(AP.getBitWidth(), + AP.getBitWidth() - AP.countTrailingZeros())); + Value *And1 = Builder->CreateAnd(BO0->getOperand(0), Mask); + Value *And2 = Builder->CreateAnd(BO1->getOperand(0), Mask); + return new ICmpInst(I.getPredicate(), And1, And2); + } + } break; - Value *P, *Q; - if (match(LHSI, m_Or(m_PtrToInt(m_Value(P)), m_PtrToInt(m_Value(Q))))) { - // Simplify icmp eq (or (ptrtoint P), (ptrtoint Q)), 0 - // -> and (icmp eq P, null), (icmp eq Q, null). - Value *ICIP = Builder->CreateICmp(ICI.getPredicate(), P, - Constant::getNullValue(P->getType())); - Value *ICIQ = Builder->CreateICmp(ICI.getPredicate(), Q, - Constant::getNullValue(Q->getType())); - Instruction *Op; - if (ICI.getPredicate() == ICmpInst::ICMP_EQ) - Op = BinaryOperator::CreateAnd(ICIP, ICIQ); - else - Op = BinaryOperator::CreateOr(ICIP, ICIQ); - return Op; + case Instruction::UDiv: + case Instruction::LShr: + if (I.isSigned()) + break; + LLVM_FALLTHROUGH; + case Instruction::SDiv: + case Instruction::AShr: + if (!BO0->isExact() || !BO1->isExact()) + break; + return new ICmpInst(I.getPredicate(), BO0->getOperand(0), + BO1->getOperand(0)); + case Instruction::Shl: { + bool NUW = BO0->hasNoUnsignedWrap() && BO1->hasNoUnsignedWrap(); + bool NSW = BO0->hasNoSignedWrap() && BO1->hasNoSignedWrap(); + if (!NUW && !NSW) + break; + if (!NSW && I.isSigned()) + break; + return new ICmpInst(I.getPredicate(), BO0->getOperand(0), + BO1->getOperand(0)); + } } - break; } - case Instruction::Mul: { // (icmp pred (mul X, Val), CI) - ConstantInt *Val = dyn_cast<ConstantInt>(LHSI->getOperand(1)); - if (!Val) break; - - // If this is a signed comparison to 0 and the mul is sign preserving, - // use the mul LHS operand instead. - ICmpInst::Predicate pred = ICI.getPredicate(); - if (isSignTest(pred, RHS) && !Val->isZero() && - cast<BinaryOperator>(LHSI)->hasNoSignedWrap()) - return new ICmpInst(Val->isNegative() ? - ICmpInst::getSwappedPredicate(pred) : pred, - LHSI->getOperand(0), - Constant::getNullValue(RHS->getType())); + if (BO0) { + // Transform A & (L - 1) `ult` L --> L != 0 + auto LSubOne = m_Add(m_Specific(Op1), m_AllOnes()); + auto BitwiseAnd = + m_CombineOr(m_And(m_Value(), LSubOne), m_And(LSubOne, m_Value())); - break; + if (match(BO0, BitwiseAnd) && I.getPredicate() == ICmpInst::ICMP_ULT) { + auto *Zero = Constant::getNullValue(BO0->getType()); + return new ICmpInst(ICmpInst::ICMP_NE, Op1, Zero); + } } - case Instruction::Shl: { // (icmp pred (shl X, ShAmt), CI) - uint32_t TypeBits = RHSV.getBitWidth(); - ConstantInt *ShAmt = dyn_cast<ConstantInt>(LHSI->getOperand(1)); - if (!ShAmt) { - Value *X; - // (1 << X) pred P2 -> X pred Log2(P2) - if (match(LHSI, m_Shl(m_One(), m_Value(X)))) { - bool RHSVIsPowerOf2 = RHSV.isPowerOf2(); - ICmpInst::Predicate Pred = ICI.getPredicate(); - if (ICI.isUnsigned()) { - if (!RHSVIsPowerOf2) { - // (1 << X) < 30 -> X <= 4 - // (1 << X) <= 30 -> X <= 4 - // (1 << X) >= 30 -> X > 4 - // (1 << X) > 30 -> X > 4 - if (Pred == ICmpInst::ICMP_ULT) - Pred = ICmpInst::ICMP_ULE; - else if (Pred == ICmpInst::ICMP_UGE) - Pred = ICmpInst::ICMP_UGT; - } - unsigned RHSLog2 = RHSV.logBase2(); + return nullptr; +} - // (1 << X) >= 2147483648 -> X >= 31 -> X == 31 - // (1 << X) < 2147483648 -> X < 31 -> X != 31 - if (RHSLog2 == TypeBits-1) { - if (Pred == ICmpInst::ICMP_UGE) - Pred = ICmpInst::ICMP_EQ; - else if (Pred == ICmpInst::ICMP_ULT) - Pred = ICmpInst::ICMP_NE; - } +/// Fold icmp Pred min|max(X, Y), X. +static Instruction *foldICmpWithMinMax(ICmpInst &Cmp) { + ICmpInst::Predicate Pred = Cmp.getPredicate(); + Value *Op0 = Cmp.getOperand(0); + Value *X = Cmp.getOperand(1); - return new ICmpInst(Pred, X, - ConstantInt::get(RHS->getType(), RHSLog2)); - } else if (ICI.isSigned()) { - if (RHSV.isAllOnesValue()) { - // (1 << X) <= -1 -> X == 31 - if (Pred == ICmpInst::ICMP_SLE) - return new ICmpInst(ICmpInst::ICMP_EQ, X, - ConstantInt::get(RHS->getType(), TypeBits-1)); + // Canonicalize minimum or maximum operand to LHS of the icmp. + if (match(X, m_c_SMin(m_Specific(Op0), m_Value())) || + match(X, m_c_SMax(m_Specific(Op0), m_Value())) || + match(X, m_c_UMin(m_Specific(Op0), m_Value())) || + match(X, m_c_UMax(m_Specific(Op0), m_Value()))) { + std::swap(Op0, X); + Pred = Cmp.getSwappedPredicate(); + } - // (1 << X) > -1 -> X != 31 - if (Pred == ICmpInst::ICMP_SGT) - return new ICmpInst(ICmpInst::ICMP_NE, X, - ConstantInt::get(RHS->getType(), TypeBits-1)); - } else if (!RHSV) { - // (1 << X) < 0 -> X == 31 - // (1 << X) <= 0 -> X == 31 - if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) - return new ICmpInst(ICmpInst::ICMP_EQ, X, - ConstantInt::get(RHS->getType(), TypeBits-1)); + Value *Y; + if (match(Op0, m_c_SMin(m_Specific(X), m_Value(Y)))) { + // smin(X, Y) == X --> X s<= Y + // smin(X, Y) s>= X --> X s<= Y + if (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_SGE) + return new ICmpInst(ICmpInst::ICMP_SLE, X, Y); - // (1 << X) >= 0 -> X != 31 - // (1 << X) > 0 -> X != 31 - if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) - return new ICmpInst(ICmpInst::ICMP_NE, X, - ConstantInt::get(RHS->getType(), TypeBits-1)); - } - } else if (ICI.isEquality()) { - if (RHSVIsPowerOf2) - return new ICmpInst( - Pred, X, ConstantInt::get(RHS->getType(), RHSV.logBase2())); - } - } - break; - } + // smin(X, Y) != X --> X s> Y + // smin(X, Y) s< X --> X s> Y + if (Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_SLT) + return new ICmpInst(ICmpInst::ICMP_SGT, X, Y); - // Check that the shift amount is in range. If not, don't perform - // undefined shifts. When the shift is visited it will be - // simplified. - if (ShAmt->uge(TypeBits)) - break; + // These cases should be handled in InstSimplify: + // smin(X, Y) s<= X --> true + // smin(X, Y) s> X --> false + return nullptr; + } - if (ICI.isEquality()) { - // If we are comparing against bits always shifted out, the - // comparison cannot succeed. - Constant *Comp = - ConstantExpr::getShl(ConstantExpr::getLShr(RHS, ShAmt), - ShAmt); - if (Comp != RHS) {// Comparing against a bit that we know is zero. - bool IsICMP_NE = ICI.getPredicate() == ICmpInst::ICMP_NE; - Constant *Cst = Builder->getInt1(IsICMP_NE); - return replaceInstUsesWith(ICI, Cst); - } + if (match(Op0, m_c_SMax(m_Specific(X), m_Value(Y)))) { + // smax(X, Y) == X --> X s>= Y + // smax(X, Y) s<= X --> X s>= Y + if (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_SLE) + return new ICmpInst(ICmpInst::ICMP_SGE, X, Y); - // If the shift is NUW, then it is just shifting out zeros, no need for an - // AND. - if (cast<BinaryOperator>(LHSI)->hasNoUnsignedWrap()) - return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0), - ConstantExpr::getLShr(RHS, ShAmt)); + // smax(X, Y) != X --> X s< Y + // smax(X, Y) s> X --> X s< Y + if (Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_SGT) + return new ICmpInst(ICmpInst::ICMP_SLT, X, Y); - // If the shift is NSW and we compare to 0, then it is just shifting out - // sign bits, no need for an AND either. - if (cast<BinaryOperator>(LHSI)->hasNoSignedWrap() && RHSV == 0) - return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0), - ConstantExpr::getLShr(RHS, ShAmt)); + // These cases should be handled in InstSimplify: + // smax(X, Y) s>= X --> true + // smax(X, Y) s< X --> false + return nullptr; + } - if (LHSI->hasOneUse()) { - // Otherwise strength reduce the shift into an and. - uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits); - Constant *Mask = Builder->getInt(APInt::getLowBitsSet(TypeBits, - TypeBits - ShAmtVal)); + if (match(Op0, m_c_UMin(m_Specific(X), m_Value(Y)))) { + // umin(X, Y) == X --> X u<= Y + // umin(X, Y) u>= X --> X u<= Y + if (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_UGE) + return new ICmpInst(ICmpInst::ICMP_ULE, X, Y); - Value *And = - Builder->CreateAnd(LHSI->getOperand(0),Mask, LHSI->getName()+".mask"); - return new ICmpInst(ICI.getPredicate(), And, - ConstantExpr::getLShr(RHS, ShAmt)); - } - } + // umin(X, Y) != X --> X u> Y + // umin(X, Y) u< X --> X u> Y + if (Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_ULT) + return new ICmpInst(ICmpInst::ICMP_UGT, X, Y); - // If this is a signed comparison to 0 and the shift is sign preserving, - // use the shift LHS operand instead. - ICmpInst::Predicate pred = ICI.getPredicate(); - if (isSignTest(pred, RHS) && - cast<BinaryOperator>(LHSI)->hasNoSignedWrap()) - return new ICmpInst(pred, - LHSI->getOperand(0), - Constant::getNullValue(RHS->getType())); + // These cases should be handled in InstSimplify: + // umin(X, Y) u<= X --> true + // umin(X, Y) u> X --> false + return nullptr; + } - // Otherwise, if this is a comparison of the sign bit, simplify to and/test. - bool TrueIfSigned = false; - if (LHSI->hasOneUse() && - isSignBitCheck(ICI.getPredicate(), RHS, TrueIfSigned)) { - // (X << 31) <s 0 --> (X&1) != 0 - Constant *Mask = ConstantInt::get(LHSI->getOperand(0)->getType(), - APInt::getOneBitSet(TypeBits, - TypeBits-ShAmt->getZExtValue()-1)); - Value *And = - Builder->CreateAnd(LHSI->getOperand(0), Mask, LHSI->getName()+".mask"); - return new ICmpInst(TrueIfSigned ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ, - And, Constant::getNullValue(And->getType())); - } + if (match(Op0, m_c_UMax(m_Specific(X), m_Value(Y)))) { + // umax(X, Y) == X --> X u>= Y + // umax(X, Y) u<= X --> X u>= Y + if (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_ULE) + return new ICmpInst(ICmpInst::ICMP_UGE, X, Y); - // Transform (icmp pred iM (shl iM %v, N), CI) - // -> (icmp pred i(M-N) (trunc %v iM to i(M-N)), (trunc (CI>>N)) - // Transform the shl to a trunc if (trunc (CI>>N)) has no loss and M-N. - // This enables to get rid of the shift in favor of a trunc which can be - // free on the target. It has the additional benefit of comparing to a - // smaller constant, which will be target friendly. - unsigned Amt = ShAmt->getLimitedValue(TypeBits-1); - if (LHSI->hasOneUse() && - Amt != 0 && RHSV.countTrailingZeros() >= Amt) { - Type *NTy = IntegerType::get(ICI.getContext(), TypeBits - Amt); - Constant *NCI = ConstantExpr::getTrunc( - ConstantExpr::getAShr(RHS, - ConstantInt::get(RHS->getType(), Amt)), - NTy); - return new ICmpInst(ICI.getPredicate(), - Builder->CreateTrunc(LHSI->getOperand(0), NTy), - NCI); - } + // umax(X, Y) != X --> X u< Y + // umax(X, Y) u> X --> X u< Y + if (Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_UGT) + return new ICmpInst(ICmpInst::ICMP_ULT, X, Y); - break; + // These cases should be handled in InstSimplify: + // umax(X, Y) u>= X --> true + // umax(X, Y) u< X --> false + return nullptr; } - case Instruction::LShr: // (icmp pred (shr X, ShAmt), CI) - case Instruction::AShr: { - // Handle equality comparisons of shift-by-constant. - BinaryOperator *BO = cast<BinaryOperator>(LHSI); - if (ConstantInt *ShAmt = dyn_cast<ConstantInt>(LHSI->getOperand(1))) { - if (Instruction *Res = FoldICmpShrCst(ICI, BO, ShAmt)) - return Res; - } + return nullptr; +} - // Handle exact shr's. - if (ICI.isEquality() && BO->isExact() && BO->hasOneUse()) { - if (RHSV.isMinValue()) - return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), RHS); - } - break; - } +Instruction *InstCombiner::foldICmpEquality(ICmpInst &I) { + if (!I.isEquality()) + return nullptr; - case Instruction::UDiv: - if (ConstantInt *DivLHS = dyn_cast<ConstantInt>(LHSI->getOperand(0))) { - Value *X = LHSI->getOperand(1); - const APInt &C1 = RHS->getValue(); - const APInt &C2 = DivLHS->getValue(); - assert(C2 != 0 && "udiv 0, X should have been simplified already."); - // (icmp ugt (udiv C2, X), C1) -> (icmp ule X, C2/(C1+1)) - if (ICI.getPredicate() == ICmpInst::ICMP_UGT) { - assert(!C1.isMaxValue() && - "icmp ugt X, UINT_MAX should have been simplified already."); - return new ICmpInst(ICmpInst::ICMP_ULE, X, - ConstantInt::get(X->getType(), C2.udiv(C1 + 1))); - } - // (icmp ult (udiv C2, X), C1) -> (icmp ugt X, C2/C1) - if (ICI.getPredicate() == ICmpInst::ICMP_ULT) { - assert(C1 != 0 && "icmp ult X, 0 should have been simplified already."); - return new ICmpInst(ICmpInst::ICMP_UGT, X, - ConstantInt::get(X->getType(), C2.udiv(C1))); - } + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Value *A, *B, *C, *D; + if (match(Op0, m_Xor(m_Value(A), m_Value(B)))) { + if (A == Op1 || B == Op1) { // (A^B) == A -> B == 0 + Value *OtherVal = A == Op1 ? B : A; + return new ICmpInst(I.getPredicate(), OtherVal, + Constant::getNullValue(A->getType())); } - // fall-through - case Instruction::SDiv: - // Fold: icmp pred ([us]div X, C1), C2 -> range test - // Fold this div into the comparison, producing a range check. - // Determine, based on the divide type, what the range is being - // checked. If there is an overflow on the low or high side, remember - // it, otherwise compute the range [low, hi) bounding the new value. - // See: InsertRangeTest above for the kinds of replacements possible. - if (ConstantInt *DivRHS = dyn_cast<ConstantInt>(LHSI->getOperand(1))) - if (Instruction *R = FoldICmpDivCst(ICI, cast<BinaryOperator>(LHSI), - DivRHS)) - return R; - break; - case Instruction::Sub: { - ConstantInt *LHSC = dyn_cast<ConstantInt>(LHSI->getOperand(0)); - if (!LHSC) break; - const APInt &LHSV = LHSC->getValue(); - - // C1-X <u C2 -> (X|(C2-1)) == C1 - // iff C1 & (C2-1) == C2-1 - // C2 is a power of 2 - if (ICI.getPredicate() == ICmpInst::ICMP_ULT && LHSI->hasOneUse() && - RHSV.isPowerOf2() && (LHSV & (RHSV - 1)) == (RHSV - 1)) - return new ICmpInst(ICmpInst::ICMP_EQ, - Builder->CreateOr(LHSI->getOperand(1), RHSV - 1), - LHSC); + if (match(Op1, m_Xor(m_Value(C), m_Value(D)))) { + // A^c1 == C^c2 --> A == C^(c1^c2) + ConstantInt *C1, *C2; + if (match(B, m_ConstantInt(C1)) && match(D, m_ConstantInt(C2)) && + Op1->hasOneUse()) { + Constant *NC = Builder->getInt(C1->getValue() ^ C2->getValue()); + Value *Xor = Builder->CreateXor(C, NC); + return new ICmpInst(I.getPredicate(), A, Xor); + } - // C1-X >u C2 -> (X|C2) != C1 - // iff C1 & C2 == C2 - // C2+1 is a power of 2 - if (ICI.getPredicate() == ICmpInst::ICMP_UGT && LHSI->hasOneUse() && - (RHSV + 1).isPowerOf2() && (LHSV & RHSV) == RHSV) - return new ICmpInst(ICmpInst::ICMP_NE, - Builder->CreateOr(LHSI->getOperand(1), RHSV), LHSC); - break; + // A^B == A^D -> B == D + if (A == C) + return new ICmpInst(I.getPredicate(), B, D); + if (A == D) + return new ICmpInst(I.getPredicate(), B, C); + if (B == C) + return new ICmpInst(I.getPredicate(), A, D); + if (B == D) + return new ICmpInst(I.getPredicate(), A, C); + } } - case Instruction::Add: - // Fold: icmp pred (add X, C1), C2 - if (!ICI.isEquality()) { - ConstantInt *LHSC = dyn_cast<ConstantInt>(LHSI->getOperand(1)); - if (!LHSC) break; - const APInt &LHSV = LHSC->getValue(); - - ConstantRange CR = ICI.makeConstantRange(ICI.getPredicate(), RHSV) - .subtract(LHSV); + if (match(Op1, m_Xor(m_Value(A), m_Value(B))) && (A == Op0 || B == Op0)) { + // A == (A^B) -> B == 0 + Value *OtherVal = A == Op0 ? B : A; + return new ICmpInst(I.getPredicate(), OtherVal, + Constant::getNullValue(A->getType())); + } - if (ICI.isSigned()) { - if (CR.getLower().isSignBit()) { - return new ICmpInst(ICmpInst::ICMP_SLT, LHSI->getOperand(0), - Builder->getInt(CR.getUpper())); - } else if (CR.getUpper().isSignBit()) { - return new ICmpInst(ICmpInst::ICMP_SGE, LHSI->getOperand(0), - Builder->getInt(CR.getLower())); - } - } else { - if (CR.getLower().isMinValue()) { - return new ICmpInst(ICmpInst::ICMP_ULT, LHSI->getOperand(0), - Builder->getInt(CR.getUpper())); - } else if (CR.getUpper().isMinValue()) { - return new ICmpInst(ICmpInst::ICMP_UGE, LHSI->getOperand(0), - Builder->getInt(CR.getLower())); - } - } + // (X&Z) == (Y&Z) -> (X^Y) & Z == 0 + if (match(Op0, m_OneUse(m_And(m_Value(A), m_Value(B)))) && + match(Op1, m_OneUse(m_And(m_Value(C), m_Value(D))))) { + Value *X = nullptr, *Y = nullptr, *Z = nullptr; - // X-C1 <u C2 -> (X & -C2) == C1 - // iff C1 & (C2-1) == 0 - // C2 is a power of 2 - if (ICI.getPredicate() == ICmpInst::ICMP_ULT && LHSI->hasOneUse() && - RHSV.isPowerOf2() && (LHSV & (RHSV - 1)) == 0) - return new ICmpInst(ICmpInst::ICMP_EQ, - Builder->CreateAnd(LHSI->getOperand(0), -RHSV), - ConstantExpr::getNeg(LHSC)); + if (A == C) { + X = B; + Y = D; + Z = A; + } else if (A == D) { + X = B; + Y = C; + Z = A; + } else if (B == C) { + X = A; + Y = D; + Z = B; + } else if (B == D) { + X = A; + Y = C; + Z = B; + } - // X-C1 >u C2 -> (X & ~C2) != C1 - // iff C1 & C2 == 0 - // C2+1 is a power of 2 - if (ICI.getPredicate() == ICmpInst::ICMP_UGT && LHSI->hasOneUse() && - (RHSV + 1).isPowerOf2() && (LHSV & RHSV) == 0) - return new ICmpInst(ICmpInst::ICMP_NE, - Builder->CreateAnd(LHSI->getOperand(0), ~RHSV), - ConstantExpr::getNeg(LHSC)); + if (X) { // Build (X^Y) & Z + Op1 = Builder->CreateXor(X, Y); + Op1 = Builder->CreateAnd(Op1, Z); + I.setOperand(0, Op1); + I.setOperand(1, Constant::getNullValue(Op1->getType())); + return &I; } - break; } - // Simplify icmp_eq and icmp_ne instructions with integer constant RHS. - if (ICI.isEquality()) { - bool isICMP_NE = ICI.getPredicate() == ICmpInst::ICMP_NE; - - // If the first operand is (add|sub|and|or|xor|rem) with a constant, and - // the second operand is a constant, simplify a bit. - if (BinaryOperator *BO = dyn_cast<BinaryOperator>(LHSI)) { - switch (BO->getOpcode()) { - case Instruction::SRem: - // If we have a signed (X % (2^c)) == 0, turn it into an unsigned one. - if (RHSV == 0 && isa<ConstantInt>(BO->getOperand(1)) &&BO->hasOneUse()){ - const APInt &V = cast<ConstantInt>(BO->getOperand(1))->getValue(); - if (V.sgt(1) && V.isPowerOf2()) { - Value *NewRem = - Builder->CreateURem(BO->getOperand(0), BO->getOperand(1), - BO->getName()); - return new ICmpInst(ICI.getPredicate(), NewRem, - Constant::getNullValue(BO->getType())); - } - } - break; - case Instruction::Add: - // Replace ((add A, B) != C) with (A != C-B) if B & C are constants. - if (ConstantInt *BOp1C = dyn_cast<ConstantInt>(BO->getOperand(1))) { - if (BO->hasOneUse()) - return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), - ConstantExpr::getSub(RHS, BOp1C)); - } else if (RHSV == 0) { - // Replace ((add A, B) != 0) with (A != -B) if A or B is - // efficiently invertible, or if the add has just this one use. - Value *BOp0 = BO->getOperand(0), *BOp1 = BO->getOperand(1); - - if (Value *NegVal = dyn_castNegVal(BOp1)) - return new ICmpInst(ICI.getPredicate(), BOp0, NegVal); - if (Value *NegVal = dyn_castNegVal(BOp0)) - return new ICmpInst(ICI.getPredicate(), NegVal, BOp1); - if (BO->hasOneUse()) { - Value *Neg = Builder->CreateNeg(BOp1); - Neg->takeName(BO); - return new ICmpInst(ICI.getPredicate(), BOp0, Neg); - } - } - break; - case Instruction::Xor: - if (BO->hasOneUse()) { - if (Constant *BOC = dyn_cast<Constant>(BO->getOperand(1))) { - // For the xor case, we can xor two constants together, eliminating - // the explicit xor. - return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), - ConstantExpr::getXor(RHS, BOC)); - } else if (RHSV == 0) { - // Replace ((xor A, B) != 0) with (A != B) - return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), - BO->getOperand(1)); - } - } - break; - case Instruction::Sub: - if (BO->hasOneUse()) { - if (ConstantInt *BOp0C = dyn_cast<ConstantInt>(BO->getOperand(0))) { - // Replace ((sub A, B) != C) with (B != A-C) if A & C are constants. - return new ICmpInst(ICI.getPredicate(), BO->getOperand(1), - ConstantExpr::getSub(BOp0C, RHS)); - } else if (RHSV == 0) { - // Replace ((sub A, B) != 0) with (A != B) - return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), - BO->getOperand(1)); - } - } - break; - case Instruction::Or: - // If bits are being or'd in that are not present in the constant we - // are comparing against, then the comparison could never succeed! - if (ConstantInt *BOC = dyn_cast<ConstantInt>(BO->getOperand(1))) { - Constant *NotCI = ConstantExpr::getNot(RHS); - if (!ConstantExpr::getAnd(BOC, NotCI)->isNullValue()) - return replaceInstUsesWith(ICI, Builder->getInt1(isICMP_NE)); + // Transform (zext A) == (B & (1<<X)-1) --> A == (trunc B) + // and (B & (1<<X)-1) == (zext A) --> A == (trunc B) + ConstantInt *Cst1; + if ((Op0->hasOneUse() && match(Op0, m_ZExt(m_Value(A))) && + match(Op1, m_And(m_Value(B), m_ConstantInt(Cst1)))) || + (Op1->hasOneUse() && match(Op0, m_And(m_Value(B), m_ConstantInt(Cst1))) && + match(Op1, m_ZExt(m_Value(A))))) { + APInt Pow2 = Cst1->getValue() + 1; + if (Pow2.isPowerOf2() && isa<IntegerType>(A->getType()) && + Pow2.logBase2() == cast<IntegerType>(A->getType())->getBitWidth()) + return new ICmpInst(I.getPredicate(), A, + Builder->CreateTrunc(B, A->getType())); + } - // Comparing if all bits outside of a constant mask are set? - // Replace (X | C) == -1 with (X & ~C) == ~C. - // This removes the -1 constant. - if (BO->hasOneUse() && RHS->isAllOnesValue()) { - Constant *NotBOC = ConstantExpr::getNot(BOC); - Value *And = Builder->CreateAnd(BO->getOperand(0), NotBOC); - return new ICmpInst(ICI.getPredicate(), And, NotBOC); - } - } - break; + // (A >> C) == (B >> C) --> (A^B) u< (1 << C) + // For lshr and ashr pairs. + if ((match(Op0, m_OneUse(m_LShr(m_Value(A), m_ConstantInt(Cst1)))) && + match(Op1, m_OneUse(m_LShr(m_Value(B), m_Specific(Cst1))))) || + (match(Op0, m_OneUse(m_AShr(m_Value(A), m_ConstantInt(Cst1)))) && + match(Op1, m_OneUse(m_AShr(m_Value(B), m_Specific(Cst1)))))) { + unsigned TypeBits = Cst1->getBitWidth(); + unsigned ShAmt = (unsigned)Cst1->getLimitedValue(TypeBits); + if (ShAmt < TypeBits && ShAmt != 0) { + ICmpInst::Predicate Pred = I.getPredicate() == ICmpInst::ICMP_NE + ? ICmpInst::ICMP_UGE + : ICmpInst::ICMP_ULT; + Value *Xor = Builder->CreateXor(A, B, I.getName() + ".unshifted"); + APInt CmpVal = APInt::getOneBitSet(TypeBits, ShAmt); + return new ICmpInst(Pred, Xor, Builder->getInt(CmpVal)); + } + } - case Instruction::And: - if (ConstantInt *BOC = dyn_cast<ConstantInt>(BO->getOperand(1))) { - // If bits are being compared against that are and'd out, then the - // comparison can never succeed! - if ((RHSV & ~BOC->getValue()) != 0) - return replaceInstUsesWith(ICI, Builder->getInt1(isICMP_NE)); + // (A << C) == (B << C) --> ((A^B) & (~0U >> C)) == 0 + if (match(Op0, m_OneUse(m_Shl(m_Value(A), m_ConstantInt(Cst1)))) && + match(Op1, m_OneUse(m_Shl(m_Value(B), m_Specific(Cst1))))) { + unsigned TypeBits = Cst1->getBitWidth(); + unsigned ShAmt = (unsigned)Cst1->getLimitedValue(TypeBits); + if (ShAmt < TypeBits && ShAmt != 0) { + Value *Xor = Builder->CreateXor(A, B, I.getName() + ".unshifted"); + APInt AndVal = APInt::getLowBitsSet(TypeBits, TypeBits - ShAmt); + Value *And = Builder->CreateAnd(Xor, Builder->getInt(AndVal), + I.getName() + ".mask"); + return new ICmpInst(I.getPredicate(), And, + Constant::getNullValue(Cst1->getType())); + } + } - // If we have ((X & C) == C), turn it into ((X & C) != 0). - if (RHS == BOC && RHSV.isPowerOf2()) - return new ICmpInst(isICMP_NE ? ICmpInst::ICMP_EQ : - ICmpInst::ICMP_NE, LHSI, - Constant::getNullValue(RHS->getType())); + // Transform "icmp eq (trunc (lshr(X, cst1)), cst" to + // "icmp (and X, mask), cst" + uint64_t ShAmt = 0; + if (Op0->hasOneUse() && + match(Op0, m_Trunc(m_OneUse(m_LShr(m_Value(A), m_ConstantInt(ShAmt))))) && + match(Op1, m_ConstantInt(Cst1)) && + // Only do this when A has multiple uses. This is most important to do + // when it exposes other optimizations. + !A->hasOneUse()) { + unsigned ASize = cast<IntegerType>(A->getType())->getPrimitiveSizeInBits(); - // Don't perform the following transforms if the AND has multiple uses - if (!BO->hasOneUse()) - break; + if (ShAmt < ASize) { + APInt MaskV = + APInt::getLowBitsSet(ASize, Op0->getType()->getPrimitiveSizeInBits()); + MaskV <<= ShAmt; - // Replace (and X, (1 << size(X)-1) != 0) with x s< 0 - if (BOC->getValue().isSignBit()) { - Value *X = BO->getOperand(0); - Constant *Zero = Constant::getNullValue(X->getType()); - ICmpInst::Predicate pred = isICMP_NE ? - ICmpInst::ICMP_SLT : ICmpInst::ICMP_SGE; - return new ICmpInst(pred, X, Zero); - } + APInt CmpV = Cst1->getValue().zext(ASize); + CmpV <<= ShAmt; - // ((X & ~7) == 0) --> X < 8 - if (RHSV == 0 && isHighOnes(BOC)) { - Value *X = BO->getOperand(0); - Constant *NegX = ConstantExpr::getNeg(BOC); - ICmpInst::Predicate pred = isICMP_NE ? - ICmpInst::ICMP_UGE : ICmpInst::ICMP_ULT; - return new ICmpInst(pred, X, NegX); - } - } - break; - case Instruction::Mul: - if (RHSV == 0 && BO->hasNoSignedWrap()) { - if (ConstantInt *BOC = dyn_cast<ConstantInt>(BO->getOperand(1))) { - // The trivial case (mul X, 0) is handled by InstSimplify - // General case : (mul X, C) != 0 iff X != 0 - // (mul X, C) == 0 iff X == 0 - if (!BOC->isZero()) - return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), - Constant::getNullValue(RHS->getType())); - } - } - break; - default: break; - } - } else if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(LHSI)) { - // Handle icmp {eq|ne} <intrinsic>, intcst. - switch (II->getIntrinsicID()) { - case Intrinsic::bswap: - Worklist.Add(II); - ICI.setOperand(0, II->getArgOperand(0)); - ICI.setOperand(1, Builder->getInt(RHSV.byteSwap())); - return &ICI; - case Intrinsic::ctlz: - case Intrinsic::cttz: - // ctz(A) == bitwidth(a) -> A == 0 and likewise for != - if (RHSV == RHS->getType()->getBitWidth()) { - Worklist.Add(II); - ICI.setOperand(0, II->getArgOperand(0)); - ICI.setOperand(1, ConstantInt::get(RHS->getType(), 0)); - return &ICI; - } - break; - case Intrinsic::ctpop: - // popcount(A) == 0 -> A == 0 and likewise for != - if (RHS->isZero()) { - Worklist.Add(II); - ICI.setOperand(0, II->getArgOperand(0)); - ICI.setOperand(1, RHS); - return &ICI; - } - break; - default: - break; - } + Value *Mask = Builder->CreateAnd(A, Builder->getInt(MaskV)); + return new ICmpInst(I.getPredicate(), Mask, Builder->getInt(CmpV)); } } + return nullptr; } /// Handle icmp (cast x to y), (cast/cst). We only handle extending casts so /// far. -Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICmp) { +Instruction *InstCombiner::foldICmpWithCastAndCast(ICmpInst &ICmp) { const CastInst *LHSCI = cast<CastInst>(ICmp.getOperand(0)); Value *LHSCIOp = LHSCI->getOperand(0); Type *SrcTy = LHSCIOp->getType(); @@ -2485,92 +3380,6 @@ Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICmp) { return BinaryOperator::CreateNot(Result); } -/// The caller has matched a pattern of the form: -/// I = icmp ugt (add (add A, B), CI2), CI1 -/// If this is of the form: -/// sum = a + b -/// if (sum+128 >u 255) -/// Then replace it with llvm.sadd.with.overflow.i8. -/// -static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, - ConstantInt *CI2, ConstantInt *CI1, - InstCombiner &IC) { - // The transformation we're trying to do here is to transform this into an - // llvm.sadd.with.overflow. To do this, we have to replace the original add - // with a narrower add, and discard the add-with-constant that is part of the - // range check (if we can't eliminate it, this isn't profitable). - - // In order to eliminate the add-with-constant, the compare can be its only - // use. - Instruction *AddWithCst = cast<Instruction>(I.getOperand(0)); - if (!AddWithCst->hasOneUse()) return nullptr; - - // If CI2 is 2^7, 2^15, 2^31, then it might be an sadd.with.overflow. - if (!CI2->getValue().isPowerOf2()) return nullptr; - unsigned NewWidth = CI2->getValue().countTrailingZeros(); - if (NewWidth != 7 && NewWidth != 15 && NewWidth != 31) return nullptr; - - // The width of the new add formed is 1 more than the bias. - ++NewWidth; - - // Check to see that CI1 is an all-ones value with NewWidth bits. - if (CI1->getBitWidth() == NewWidth || - CI1->getValue() != APInt::getLowBitsSet(CI1->getBitWidth(), NewWidth)) - return nullptr; - - // This is only really a signed overflow check if the inputs have been - // sign-extended; check for that condition. For example, if CI2 is 2^31 and - // the operands of the add are 64 bits wide, we need at least 33 sign bits. - unsigned NeededSignBits = CI1->getBitWidth() - NewWidth + 1; - if (IC.ComputeNumSignBits(A, 0, &I) < NeededSignBits || - IC.ComputeNumSignBits(B, 0, &I) < NeededSignBits) - return nullptr; - - // In order to replace the original add with a narrower - // llvm.sadd.with.overflow, the only uses allowed are the add-with-constant - // and truncates that discard the high bits of the add. Verify that this is - // the case. - Instruction *OrigAdd = cast<Instruction>(AddWithCst->getOperand(0)); - for (User *U : OrigAdd->users()) { - if (U == AddWithCst) continue; - - // Only accept truncates for now. We would really like a nice recursive - // predicate like SimplifyDemandedBits, but which goes downwards the use-def - // chain to see which bits of a value are actually demanded. If the - // original add had another add which was then immediately truncated, we - // could still do the transformation. - TruncInst *TI = dyn_cast<TruncInst>(U); - if (!TI || TI->getType()->getPrimitiveSizeInBits() > NewWidth) - return nullptr; - } - - // If the pattern matches, truncate the inputs to the narrower type and - // use the sadd_with_overflow intrinsic to efficiently compute both the - // result and the overflow bit. - Type *NewType = IntegerType::get(OrigAdd->getContext(), NewWidth); - Value *F = Intrinsic::getDeclaration(I.getModule(), - Intrinsic::sadd_with_overflow, NewType); - - InstCombiner::BuilderTy *Builder = IC.Builder; - - // Put the new code above the original add, in case there are any uses of the - // add between the add and the compare. - Builder->SetInsertPoint(OrigAdd); - - Value *TruncA = Builder->CreateTrunc(A, NewType, A->getName()+".trunc"); - Value *TruncB = Builder->CreateTrunc(B, NewType, B->getName()+".trunc"); - CallInst *Call = Builder->CreateCall(F, {TruncA, TruncB}, "sadd"); - Value *Add = Builder->CreateExtractValue(Call, 0, "sadd.result"); - Value *ZExt = Builder->CreateZExt(Add, OrigAdd->getType()); - - // The inner add was the result of the narrow add, zero extended to the - // wider type. Replace it with the result computed by the intrinsic. - IC.replaceInstUsesWith(*OrigAdd, ZExt); - - // The original icmp gets replaced with the overflow value. - return ExtractValueInst::Create(Call, 1, "sadd.overflow"); -} - bool InstCombiner::OptimizeOverflowCheck(OverflowCheckFlavor OCF, Value *LHS, Value *RHS, Instruction &OrigI, Value *&Result, Constant *&Overflow) { @@ -2603,8 +3412,10 @@ bool InstCombiner::OptimizeOverflowCheck(OverflowCheckFlavor OCF, Value *LHS, if (OR == OverflowResult::AlwaysOverflows) return SetResult(Builder->CreateAdd(LHS, RHS), Builder->getTrue(), true); + + // Fall through uadd into sadd + LLVM_FALLTHROUGH; } - // FALL THROUGH uadd into sadd case OCF_SIGNED_ADD: { // X + 0 -> {X, false} if (match(RHS, m_Zero())) @@ -2644,7 +3455,8 @@ bool InstCombiner::OptimizeOverflowCheck(OverflowCheckFlavor OCF, Value *LHS, true); if (OR == OverflowResult::AlwaysOverflows) return SetResult(Builder->CreateMul(LHS, RHS), Builder->getTrue(), true); - } // FALL THROUGH + LLVM_FALLTHROUGH; + } case OCF_SIGNED_MUL: // X * undef -> undef if (isa<UndefValue>(RHS)) @@ -2682,7 +3494,7 @@ bool InstCombiner::OptimizeOverflowCheck(OverflowCheckFlavor OCF, Value *LHS, /// \param OtherVal The other argument of compare instruction. /// \returns Instruction which must replace the compare instruction, NULL if no /// replacement required. -static Instruction *ProcessUMulZExtIdiom(ICmpInst &I, Value *MulVal, +static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal, Value *OtherVal, InstCombiner &IC) { // Don't bother doing this transformation for pointers, don't do it for // vectors. @@ -2906,8 +3718,8 @@ static Instruction *ProcessUMulZExtIdiom(ICmpInst &I, Value *MulVal, /// When performing a comparison against a constant, it is possible that not all /// the bits in the LHS are demanded. This helper method computes the mask that /// IS demanded. -static APInt DemandedBitsLHSMask(ICmpInst &I, - unsigned BitWidth, bool isSignCheck) { +static APInt getDemandedBitsLHSMask(ICmpInst &I, unsigned BitWidth, + bool isSignCheck) { if (isSignCheck) return APInt::getSignBit(BitWidth); @@ -2981,7 +3793,7 @@ static bool swapMayExposeCSEOpportunities(const Value * Op0, } /// \brief Check that one use is in the same block as the definition and all -/// other uses are in blocks dominated by a given block +/// other uses are in blocks dominated by a given block. /// /// \param DI Definition /// \param UI Use @@ -2994,21 +3806,18 @@ bool InstCombiner::dominatesAllUses(const Instruction *DI, const Instruction *UI, const BasicBlock *DB) const { assert(DI && UI && "Instruction not defined\n"); - // ignore incomplete definitions + // Ignore incomplete definitions. if (!DI->getParent()) return false; - // DI and UI must be in the same block + // DI and UI must be in the same block. if (DI->getParent() != UI->getParent()) return false; - // Protect from self-referencing blocks + // Protect from self-referencing blocks. if (DI->getParent() == DB) return false; - // DominatorTree available? - if (!DT) - return false; for (const User *U : DI->users()) { auto *Usr = cast<Instruction>(U); - if (Usr != UI && !DT->dominates(DB, Usr->getParent())) + if (Usr != UI && !DT.dominates(DB, Usr->getParent())) return false; } return true; @@ -3067,8 +3876,7 @@ static bool isChainSelectCmpBranch(const SelectInst *SI) { /// are equal, the optimization can work only for EQ predicates. This is not a /// major restriction since a NE compare should be 'normalized' to an equal /// compare, which usually happens in the combiner and test case -/// select-cmp-br.ll -/// checks for it. +/// select-cmp-br.ll checks for it. bool InstCombiner::replacedSelectWithOperand(SelectInst *SI, const ICmpInst *Icmp, const unsigned SIOpd) { @@ -3076,7 +3884,7 @@ bool InstCombiner::replacedSelectWithOperand(SelectInst *SI, if (isChainSelectCmpBranch(SI) && Icmp->getPredicate() == ICmpInst::ICMP_EQ) { BasicBlock *Succ = SI->getParent()->getTerminator()->getSuccessor(1); // The check for the unique predecessor is not the best that can be - // done. But it protects efficiently against cases like when SI's + // done. But it protects efficiently against cases like when SI's // home block has two successors, Succ and Succ1, and Succ1 predecessor // of Succ. Then SI can't be replaced by SIOpd because the use that gets // replaced can be reached on either path. So the uniqueness check @@ -3093,6 +3901,239 @@ bool InstCombiner::replacedSelectWithOperand(SelectInst *SI, return false; } +/// Try to fold the comparison based on range information we can get by checking +/// whether bits are known to be zero or one in the inputs. +Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Type *Ty = Op0->getType(); + ICmpInst::Predicate Pred = I.getPredicate(); + + // Get scalar or pointer size. + unsigned BitWidth = Ty->isIntOrIntVectorTy() + ? Ty->getScalarSizeInBits() + : DL.getTypeSizeInBits(Ty->getScalarType()); + + if (!BitWidth) + return nullptr; + + // If this is a normal comparison, it demands all bits. If it is a sign bit + // comparison, it only demands the sign bit. + bool IsSignBit = false; + const APInt *CmpC; + if (match(Op1, m_APInt(CmpC))) { + bool UnusedBit; + IsSignBit = isSignBitCheck(Pred, *CmpC, UnusedBit); + } + + APInt Op0KnownZero(BitWidth, 0), Op0KnownOne(BitWidth, 0); + APInt Op1KnownZero(BitWidth, 0), Op1KnownOne(BitWidth, 0); + + if (SimplifyDemandedBits(I.getOperandUse(0), + getDemandedBitsLHSMask(I, BitWidth, IsSignBit), + Op0KnownZero, Op0KnownOne, 0)) + return &I; + + if (SimplifyDemandedBits(I.getOperandUse(1), APInt::getAllOnesValue(BitWidth), + Op1KnownZero, Op1KnownOne, 0)) + return &I; + + // Given the known and unknown bits, compute a range that the LHS could be + // in. Compute the Min, Max and RHS values based on the known bits. For the + // EQ and NE we use unsigned values. + APInt Op0Min(BitWidth, 0), Op0Max(BitWidth, 0); + APInt Op1Min(BitWidth, 0), Op1Max(BitWidth, 0); + if (I.isSigned()) { + computeSignedMinMaxValuesFromKnownBits(Op0KnownZero, Op0KnownOne, Op0Min, + Op0Max); + computeSignedMinMaxValuesFromKnownBits(Op1KnownZero, Op1KnownOne, Op1Min, + Op1Max); + } else { + computeUnsignedMinMaxValuesFromKnownBits(Op0KnownZero, Op0KnownOne, Op0Min, + Op0Max); + computeUnsignedMinMaxValuesFromKnownBits(Op1KnownZero, Op1KnownOne, Op1Min, + Op1Max); + } + + // If Min and Max are known to be the same, then SimplifyDemandedBits + // figured out that the LHS is a constant. Constant fold this now, so that + // code below can assume that Min != Max. + if (!isa<Constant>(Op0) && Op0Min == Op0Max) + return new ICmpInst(Pred, ConstantInt::get(Op0->getType(), Op0Min), Op1); + if (!isa<Constant>(Op1) && Op1Min == Op1Max) + return new ICmpInst(Pred, Op0, ConstantInt::get(Op1->getType(), Op1Min)); + + // Based on the range information we know about the LHS, see if we can + // simplify this comparison. For example, (x&4) < 8 is always true. + switch (Pred) { + default: + llvm_unreachable("Unknown icmp opcode!"); + case ICmpInst::ICMP_EQ: + case ICmpInst::ICMP_NE: { + if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max)) { + return Pred == CmpInst::ICMP_EQ + ? replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())) + : replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + } + + // If all bits are known zero except for one, then we know at most one bit + // is set. If the comparison is against zero, then this is a check to see if + // *that* bit is set. + APInt Op0KnownZeroInverted = ~Op0KnownZero; + if (~Op1KnownZero == 0) { + // If the LHS is an AND with the same constant, look through it. + Value *LHS = nullptr; + const APInt *LHSC; + if (!match(Op0, m_And(m_Value(LHS), m_APInt(LHSC))) || + *LHSC != Op0KnownZeroInverted) + LHS = Op0; + + Value *X; + if (match(LHS, m_Shl(m_One(), m_Value(X)))) { + APInt ValToCheck = Op0KnownZeroInverted; + Type *XTy = X->getType(); + if (ValToCheck.isPowerOf2()) { + // ((1 << X) & 8) == 0 -> X != 3 + // ((1 << X) & 8) != 0 -> X == 3 + auto *CmpC = ConstantInt::get(XTy, ValToCheck.countTrailingZeros()); + auto NewPred = ICmpInst::getInversePredicate(Pred); + return new ICmpInst(NewPred, X, CmpC); + } else if ((++ValToCheck).isPowerOf2()) { + // ((1 << X) & 7) == 0 -> X >= 3 + // ((1 << X) & 7) != 0 -> X < 3 + auto *CmpC = ConstantInt::get(XTy, ValToCheck.countTrailingZeros()); + auto NewPred = + Pred == CmpInst::ICMP_EQ ? CmpInst::ICMP_UGE : CmpInst::ICMP_ULT; + return new ICmpInst(NewPred, X, CmpC); + } + } + + // Check if the LHS is 8 >>u x and the result is a power of 2 like 1. + const APInt *CI; + if (Op0KnownZeroInverted == 1 && + match(LHS, m_LShr(m_Power2(CI), m_Value(X)))) { + // ((8 >>u X) & 1) == 0 -> X != 3 + // ((8 >>u X) & 1) != 0 -> X == 3 + unsigned CmpVal = CI->countTrailingZeros(); + auto NewPred = ICmpInst::getInversePredicate(Pred); + return new ICmpInst(NewPred, X, ConstantInt::get(X->getType(), CmpVal)); + } + } + break; + } + case ICmpInst::ICMP_ULT: { + if (Op0Max.ult(Op1Min)) // A <u B -> true if max(A) < min(B) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + if (Op0Min.uge(Op1Max)) // A <u B -> false if min(A) >= max(B) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + if (Op1Min == Op0Max) // A <u B -> A != B if max(A) == min(B) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); + + const APInt *CmpC; + if (match(Op1, m_APInt(CmpC))) { + // A <u C -> A == C-1 if min(A)+1 == C + if (Op1Max == Op0Min + 1) { + Constant *CMinus1 = ConstantInt::get(Op0->getType(), *CmpC - 1); + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, CMinus1); + } + // (x <u 2147483648) -> (x >s -1) -> true if sign bit clear + if (CmpC->isMinSignedValue()) { + Constant *AllOnes = Constant::getAllOnesValue(Op0->getType()); + return new ICmpInst(ICmpInst::ICMP_SGT, Op0, AllOnes); + } + } + break; + } + case ICmpInst::ICMP_UGT: { + if (Op0Min.ugt(Op1Max)) // A >u B -> true if min(A) > max(B) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + + if (Op0Max.ule(Op1Min)) // A >u B -> false if max(A) <= max(B) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + + if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); + + const APInt *CmpC; + if (match(Op1, m_APInt(CmpC))) { + // A >u C -> A == C+1 if max(a)-1 == C + if (*CmpC == Op0Max - 1) + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, + ConstantInt::get(Op1->getType(), *CmpC + 1)); + + // (x >u 2147483647) -> (x <s 0) -> true if sign bit set + if (CmpC->isMaxSignedValue()) + return new ICmpInst(ICmpInst::ICMP_SLT, Op0, + Constant::getNullValue(Op0->getType())); + } + break; + } + case ICmpInst::ICMP_SLT: + if (Op0Max.slt(Op1Min)) // A <s B -> true if max(A) < min(C) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + if (Op0Min.sge(Op1Max)) // A <s B -> false if min(A) >= max(C) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); + if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { + if (Op1Max == Op0Min + 1) // A <s C -> A == C-1 if min(A)+1 == C + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, + Builder->getInt(CI->getValue() - 1)); + } + break; + case ICmpInst::ICMP_SGT: + if (Op0Min.sgt(Op1Max)) // A >s B -> true if min(A) > max(B) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + if (Op0Max.sle(Op1Min)) // A >s B -> false if max(A) <= min(B) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + + if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); + if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { + if (Op1Min == Op0Max - 1) // A >s C -> A == C+1 if max(A)-1 == C + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, + Builder->getInt(CI->getValue() + 1)); + } + break; + case ICmpInst::ICMP_SGE: + assert(!isa<ConstantInt>(Op1) && "ICMP_SGE with ConstantInt not folded!"); + if (Op0Min.sge(Op1Max)) // A >=s B -> true if min(A) >= max(B) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + if (Op0Max.slt(Op1Min)) // A >=s B -> false if max(A) < min(B) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + break; + case ICmpInst::ICMP_SLE: + assert(!isa<ConstantInt>(Op1) && "ICMP_SLE with ConstantInt not folded!"); + if (Op0Max.sle(Op1Min)) // A <=s B -> true if max(A) <= min(B) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + if (Op0Min.sgt(Op1Max)) // A <=s B -> false if min(A) > max(B) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + break; + case ICmpInst::ICMP_UGE: + assert(!isa<ConstantInt>(Op1) && "ICMP_UGE with ConstantInt not folded!"); + if (Op0Min.uge(Op1Max)) // A >=u B -> true if min(A) >= max(B) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + if (Op0Max.ult(Op1Min)) // A >=u B -> false if max(A) < min(B) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + break; + case ICmpInst::ICMP_ULE: + assert(!isa<ConstantInt>(Op1) && "ICMP_ULE with ConstantInt not folded!"); + if (Op0Max.ule(Op1Min)) // A <=u B -> true if max(A) <= min(B) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + if (Op0Min.ugt(Op1Max)) // A <=u B -> false if min(A) > max(B) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + break; + } + + // Turn a signed comparison into an unsigned one if both operands are known to + // have the same sign. + if (I.isSigned() && + ((Op0KnownZero.isNegative() && Op1KnownZero.isNegative()) || + (Op0KnownOne.isNegative() && Op1KnownOne.isNegative()))) + return new ICmpInst(I.getUnsignedPredicate(), Op0, Op1); + + return nullptr; +} + /// If we have an icmp le or icmp ge instruction with a constant operand, turn /// it into the appropriate icmp lt or icmp gt instruction. This transform /// allows them to be folded in visitICmpInst. @@ -3131,6 +4172,7 @@ static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) { if (isa<UndefValue>(Elt)) continue; + // Bail out if we can't determine if this constant is min/max or if we // know that this constant is min/max. auto *CI = dyn_cast<ConstantInt>(Elt); @@ -3167,7 +4209,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { } if (Value *V = - SimplifyICmpInst(I.getPredicate(), Op0, Op1, DL, TLI, DT, AC, &I)) + SimplifyICmpInst(I.getPredicate(), Op0, Op1, DL, &TLI, &DT, &AC, &I)) return replaceInstUsesWith(I, V); // comparing -val or val with non-zero is the same as just comparing val @@ -3202,28 +4244,28 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { case ICmpInst::ICMP_UGT: std::swap(Op0, Op1); // Change icmp ugt -> icmp ult - // FALL THROUGH + LLVM_FALLTHROUGH; case ICmpInst::ICMP_ULT:{ // icmp ult i1 A, B -> ~A & B Value *Not = Builder->CreateNot(Op0, I.getName() + "tmp"); return BinaryOperator::CreateAnd(Not, Op1); } case ICmpInst::ICMP_SGT: std::swap(Op0, Op1); // Change icmp sgt -> icmp slt - // FALL THROUGH + LLVM_FALLTHROUGH; case ICmpInst::ICMP_SLT: { // icmp slt i1 A, B -> A & ~B Value *Not = Builder->CreateNot(Op1, I.getName() + "tmp"); return BinaryOperator::CreateAnd(Not, Op0); } case ICmpInst::ICMP_UGE: std::swap(Op0, Op1); // Change icmp uge -> icmp ule - // FALL THROUGH + LLVM_FALLTHROUGH; case ICmpInst::ICMP_ULE: { // icmp ule i1 A, B -> ~A | B Value *Not = Builder->CreateNot(Op0, I.getName() + "tmp"); return BinaryOperator::CreateOr(Not, Op1); } case ICmpInst::ICMP_SGE: std::swap(Op0, Op1); // Change icmp sge -> icmp sle - // FALL THROUGH + LLVM_FALLTHROUGH; case ICmpInst::ICMP_SLE: { // icmp sle i1 A, B -> A | ~B Value *Not = Builder->CreateNot(Op1, I.getName() + "tmp"); return BinaryOperator::CreateOr(Not, Op0); @@ -3234,372 +4276,11 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { if (ICmpInst *NewICmp = canonicalizeCmpWithConstant(I)) return NewICmp; - unsigned BitWidth = 0; - if (Ty->isIntOrIntVectorTy()) - BitWidth = Ty->getScalarSizeInBits(); - else // Get pointer size. - BitWidth = DL.getTypeSizeInBits(Ty->getScalarType()); - - bool isSignBit = false; - - // See if we are doing a comparison with a constant. - if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { - Value *A = nullptr, *B = nullptr; - - // Match the following pattern, which is a common idiom when writing - // overflow-safe integer arithmetic function. The source performs an - // addition in wider type, and explicitly checks for overflow using - // comparisons against INT_MIN and INT_MAX. Simplify this by using the - // sadd_with_overflow intrinsic. - // - // TODO: This could probably be generalized to handle other overflow-safe - // operations if we worked out the formulas to compute the appropriate - // magic constants. - // - // sum = a + b - // if (sum+128 >u 255) ... -> llvm.sadd.with.overflow.i8 - { - ConstantInt *CI2; // I = icmp ugt (add (add A, B), CI2), CI - if (I.getPredicate() == ICmpInst::ICMP_UGT && - match(Op0, m_Add(m_Add(m_Value(A), m_Value(B)), m_ConstantInt(CI2)))) - if (Instruction *Res = ProcessUGT_ADDCST_ADD(I, A, B, CI2, CI, *this)) - return Res; - } - - // (icmp sgt smin(PosA, B) 0) -> (icmp sgt B 0) - if (CI->isZero() && I.getPredicate() == ICmpInst::ICMP_SGT) - if (auto *SI = dyn_cast<SelectInst>(Op0)) { - SelectPatternResult SPR = matchSelectPattern(SI, A, B); - if (SPR.Flavor == SPF_SMIN) { - if (isKnownPositive(A, DL)) - return new ICmpInst(I.getPredicate(), B, CI); - if (isKnownPositive(B, DL)) - return new ICmpInst(I.getPredicate(), A, CI); - } - } - - - // The following transforms are only 'worth it' if the only user of the - // subtraction is the icmp. - if (Op0->hasOneUse()) { - // (icmp ne/eq (sub A B) 0) -> (icmp ne/eq A, B) - if (I.isEquality() && CI->isZero() && - match(Op0, m_Sub(m_Value(A), m_Value(B)))) - return new ICmpInst(I.getPredicate(), A, B); - - // (icmp sgt (sub nsw A B), -1) -> (icmp sge A, B) - if (I.getPredicate() == ICmpInst::ICMP_SGT && CI->isAllOnesValue() && - match(Op0, m_NSWSub(m_Value(A), m_Value(B)))) - return new ICmpInst(ICmpInst::ICMP_SGE, A, B); - - // (icmp sgt (sub nsw A B), 0) -> (icmp sgt A, B) - if (I.getPredicate() == ICmpInst::ICMP_SGT && CI->isZero() && - match(Op0, m_NSWSub(m_Value(A), m_Value(B)))) - return new ICmpInst(ICmpInst::ICMP_SGT, A, B); - - // (icmp slt (sub nsw A B), 0) -> (icmp slt A, B) - if (I.getPredicate() == ICmpInst::ICMP_SLT && CI->isZero() && - match(Op0, m_NSWSub(m_Value(A), m_Value(B)))) - return new ICmpInst(ICmpInst::ICMP_SLT, A, B); - - // (icmp slt (sub nsw A B), 1) -> (icmp sle A, B) - if (I.getPredicate() == ICmpInst::ICMP_SLT && CI->isOne() && - match(Op0, m_NSWSub(m_Value(A), m_Value(B)))) - return new ICmpInst(ICmpInst::ICMP_SLE, A, B); - } - - if (I.isEquality()) { - ConstantInt *CI2; - if (match(Op0, m_AShr(m_ConstantInt(CI2), m_Value(A))) || - match(Op0, m_LShr(m_ConstantInt(CI2), m_Value(A)))) { - // (icmp eq/ne (ashr/lshr const2, A), const1) - if (Instruction *Inst = FoldICmpCstShrCst(I, Op0, A, CI, CI2)) - return Inst; - } - if (match(Op0, m_Shl(m_ConstantInt(CI2), m_Value(A)))) { - // (icmp eq/ne (shl const2, A), const1) - if (Instruction *Inst = FoldICmpCstShlCst(I, Op0, A, CI, CI2)) - return Inst; - } - } - - // If this comparison is a normal comparison, it demands all - // bits, if it is a sign bit comparison, it only demands the sign bit. - bool UnusedBit; - isSignBit = isSignBitCheck(I.getPredicate(), CI, UnusedBit); - - // Canonicalize icmp instructions based on dominating conditions. - BasicBlock *Parent = I.getParent(); - BasicBlock *Dom = Parent->getSinglePredecessor(); - auto *BI = Dom ? dyn_cast<BranchInst>(Dom->getTerminator()) : nullptr; - ICmpInst::Predicate Pred; - BasicBlock *TrueBB, *FalseBB; - ConstantInt *CI2; - if (BI && match(BI, m_Br(m_ICmp(Pred, m_Specific(Op0), m_ConstantInt(CI2)), - TrueBB, FalseBB)) && - TrueBB != FalseBB) { - ConstantRange CR = ConstantRange::makeAllowedICmpRegion(I.getPredicate(), - CI->getValue()); - ConstantRange DominatingCR = - (Parent == TrueBB) - ? ConstantRange::makeExactICmpRegion(Pred, CI2->getValue()) - : ConstantRange::makeExactICmpRegion( - CmpInst::getInversePredicate(Pred), CI2->getValue()); - ConstantRange Intersection = DominatingCR.intersectWith(CR); - ConstantRange Difference = DominatingCR.difference(CR); - if (Intersection.isEmptySet()) - return replaceInstUsesWith(I, Builder->getFalse()); - if (Difference.isEmptySet()) - return replaceInstUsesWith(I, Builder->getTrue()); - // Canonicalizing a sign bit comparison that gets used in a branch, - // pessimizes codegen by generating branch on zero instruction instead - // of a test and branch. So we avoid canonicalizing in such situations - // because test and branch instruction has better branch displacement - // than compare and branch instruction. - if (!isBranchOnSignBitCheck(I, isSignBit) && !I.isEquality()) { - if (auto *AI = Intersection.getSingleElement()) - return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Builder->getInt(*AI)); - if (auto *AD = Difference.getSingleElement()) - return new ICmpInst(ICmpInst::ICMP_NE, Op0, Builder->getInt(*AD)); - } - } - } - - // See if we can fold the comparison based on range information we can get - // by checking whether bits are known to be zero or one in the input. - if (BitWidth != 0) { - APInt Op0KnownZero(BitWidth, 0), Op0KnownOne(BitWidth, 0); - APInt Op1KnownZero(BitWidth, 0), Op1KnownOne(BitWidth, 0); - - if (SimplifyDemandedBits(I.getOperandUse(0), - DemandedBitsLHSMask(I, BitWidth, isSignBit), - Op0KnownZero, Op0KnownOne, 0)) - return &I; - if (SimplifyDemandedBits(I.getOperandUse(1), - APInt::getAllOnesValue(BitWidth), Op1KnownZero, - Op1KnownOne, 0)) - return &I; - - // Given the known and unknown bits, compute a range that the LHS could be - // in. Compute the Min, Max and RHS values based on the known bits. For the - // EQ and NE we use unsigned values. - APInt Op0Min(BitWidth, 0), Op0Max(BitWidth, 0); - APInt Op1Min(BitWidth, 0), Op1Max(BitWidth, 0); - if (I.isSigned()) { - ComputeSignedMinMaxValuesFromKnownBits(Op0KnownZero, Op0KnownOne, - Op0Min, Op0Max); - ComputeSignedMinMaxValuesFromKnownBits(Op1KnownZero, Op1KnownOne, - Op1Min, Op1Max); - } else { - ComputeUnsignedMinMaxValuesFromKnownBits(Op0KnownZero, Op0KnownOne, - Op0Min, Op0Max); - ComputeUnsignedMinMaxValuesFromKnownBits(Op1KnownZero, Op1KnownOne, - Op1Min, Op1Max); - } - - // If Min and Max are known to be the same, then SimplifyDemandedBits - // figured out that the LHS is a constant. Just constant fold this now so - // that code below can assume that Min != Max. - if (!isa<Constant>(Op0) && Op0Min == Op0Max) - return new ICmpInst(I.getPredicate(), - ConstantInt::get(Op0->getType(), Op0Min), Op1); - if (!isa<Constant>(Op1) && Op1Min == Op1Max) - return new ICmpInst(I.getPredicate(), Op0, - ConstantInt::get(Op1->getType(), Op1Min)); - - // Based on the range information we know about the LHS, see if we can - // simplify this comparison. For example, (x&4) < 8 is always true. - switch (I.getPredicate()) { - default: llvm_unreachable("Unknown icmp opcode!"); - case ICmpInst::ICMP_EQ: { - if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max)) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - - // If all bits are known zero except for one, then we know at most one - // bit is set. If the comparison is against zero, then this is a check - // to see if *that* bit is set. - APInt Op0KnownZeroInverted = ~Op0KnownZero; - if (~Op1KnownZero == 0) { - // If the LHS is an AND with the same constant, look through it. - Value *LHS = nullptr; - ConstantInt *LHSC = nullptr; - if (!match(Op0, m_And(m_Value(LHS), m_ConstantInt(LHSC))) || - LHSC->getValue() != Op0KnownZeroInverted) - LHS = Op0; - - // If the LHS is 1 << x, and we know the result is a power of 2 like 8, - // then turn "((1 << x)&8) == 0" into "x != 3". - // or turn "((1 << x)&7) == 0" into "x > 2". - Value *X = nullptr; - if (match(LHS, m_Shl(m_One(), m_Value(X)))) { - APInt ValToCheck = Op0KnownZeroInverted; - if (ValToCheck.isPowerOf2()) { - unsigned CmpVal = ValToCheck.countTrailingZeros(); - return new ICmpInst(ICmpInst::ICMP_NE, X, - ConstantInt::get(X->getType(), CmpVal)); - } else if ((++ValToCheck).isPowerOf2()) { - unsigned CmpVal = ValToCheck.countTrailingZeros() - 1; - return new ICmpInst(ICmpInst::ICMP_UGT, X, - ConstantInt::get(X->getType(), CmpVal)); - } - } - - // If the LHS is 8 >>u x, and we know the result is a power of 2 like 1, - // then turn "((8 >>u x)&1) == 0" into "x != 3". - const APInt *CI; - if (Op0KnownZeroInverted == 1 && - match(LHS, m_LShr(m_Power2(CI), m_Value(X)))) - return new ICmpInst(ICmpInst::ICMP_NE, X, - ConstantInt::get(X->getType(), - CI->countTrailingZeros())); - } - break; - } - case ICmpInst::ICMP_NE: { - if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max)) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - - // If all bits are known zero except for one, then we know at most one - // bit is set. If the comparison is against zero, then this is a check - // to see if *that* bit is set. - APInt Op0KnownZeroInverted = ~Op0KnownZero; - if (~Op1KnownZero == 0) { - // If the LHS is an AND with the same constant, look through it. - Value *LHS = nullptr; - ConstantInt *LHSC = nullptr; - if (!match(Op0, m_And(m_Value(LHS), m_ConstantInt(LHSC))) || - LHSC->getValue() != Op0KnownZeroInverted) - LHS = Op0; - - // If the LHS is 1 << x, and we know the result is a power of 2 like 8, - // then turn "((1 << x)&8) != 0" into "x == 3". - // or turn "((1 << x)&7) != 0" into "x < 3". - Value *X = nullptr; - if (match(LHS, m_Shl(m_One(), m_Value(X)))) { - APInt ValToCheck = Op0KnownZeroInverted; - if (ValToCheck.isPowerOf2()) { - unsigned CmpVal = ValToCheck.countTrailingZeros(); - return new ICmpInst(ICmpInst::ICMP_EQ, X, - ConstantInt::get(X->getType(), CmpVal)); - } else if ((++ValToCheck).isPowerOf2()) { - unsigned CmpVal = ValToCheck.countTrailingZeros(); - return new ICmpInst(ICmpInst::ICMP_ULT, X, - ConstantInt::get(X->getType(), CmpVal)); - } - } - - // If the LHS is 8 >>u x, and we know the result is a power of 2 like 1, - // then turn "((8 >>u x)&1) != 0" into "x == 3". - const APInt *CI; - if (Op0KnownZeroInverted == 1 && - match(LHS, m_LShr(m_Power2(CI), m_Value(X)))) - return new ICmpInst(ICmpInst::ICMP_EQ, X, - ConstantInt::get(X->getType(), - CI->countTrailingZeros())); - } - break; - } - case ICmpInst::ICMP_ULT: - if (Op0Max.ult(Op1Min)) // A <u B -> true if max(A) < min(B) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Min.uge(Op1Max)) // A <u B -> false if min(A) >= max(B) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - if (Op1Min == Op0Max) // A <u B -> A != B if max(A) == min(B) - return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); - if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { - if (Op1Max == Op0Min+1) // A <u C -> A == C-1 if min(A)+1 == C - return new ICmpInst(ICmpInst::ICMP_EQ, Op0, - Builder->getInt(CI->getValue()-1)); - - // (x <u 2147483648) -> (x >s -1) -> true if sign bit clear - if (CI->isMinValue(true)) - return new ICmpInst(ICmpInst::ICMP_SGT, Op0, - Constant::getAllOnesValue(Op0->getType())); - } - break; - case ICmpInst::ICMP_UGT: - if (Op0Min.ugt(Op1Max)) // A >u B -> true if min(A) > max(B) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Max.ule(Op1Min)) // A >u B -> false if max(A) <= max(B) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - - if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B) - return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); - if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { - if (Op1Min == Op0Max-1) // A >u C -> A == C+1 if max(a)-1 == C - return new ICmpInst(ICmpInst::ICMP_EQ, Op0, - Builder->getInt(CI->getValue()+1)); - - // (x >u 2147483647) -> (x <s 0) -> true if sign bit set - if (CI->isMaxValue(true)) - return new ICmpInst(ICmpInst::ICMP_SLT, Op0, - Constant::getNullValue(Op0->getType())); - } - break; - case ICmpInst::ICMP_SLT: - if (Op0Max.slt(Op1Min)) // A <s B -> true if max(A) < min(C) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Min.sge(Op1Max)) // A <s B -> false if min(A) >= max(C) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B) - return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); - if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { - if (Op1Max == Op0Min+1) // A <s C -> A == C-1 if min(A)+1 == C - return new ICmpInst(ICmpInst::ICMP_EQ, Op0, - Builder->getInt(CI->getValue()-1)); - } - break; - case ICmpInst::ICMP_SGT: - if (Op0Min.sgt(Op1Max)) // A >s B -> true if min(A) > max(B) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Max.sle(Op1Min)) // A >s B -> false if max(A) <= min(B) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - - if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B) - return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); - if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { - if (Op1Min == Op0Max-1) // A >s C -> A == C+1 if max(A)-1 == C - return new ICmpInst(ICmpInst::ICMP_EQ, Op0, - Builder->getInt(CI->getValue()+1)); - } - break; - case ICmpInst::ICMP_SGE: - assert(!isa<ConstantInt>(Op1) && "ICMP_SGE with ConstantInt not folded!"); - if (Op0Min.sge(Op1Max)) // A >=s B -> true if min(A) >= max(B) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Max.slt(Op1Min)) // A >=s B -> false if max(A) < min(B) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - break; - case ICmpInst::ICMP_SLE: - assert(!isa<ConstantInt>(Op1) && "ICMP_SLE with ConstantInt not folded!"); - if (Op0Max.sle(Op1Min)) // A <=s B -> true if max(A) <= min(B) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Min.sgt(Op1Max)) // A <=s B -> false if min(A) > max(B) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - break; - case ICmpInst::ICMP_UGE: - assert(!isa<ConstantInt>(Op1) && "ICMP_UGE with ConstantInt not folded!"); - if (Op0Min.uge(Op1Max)) // A >=u B -> true if min(A) >= max(B) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Max.ult(Op1Min)) // A >=u B -> false if max(A) < min(B) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - break; - case ICmpInst::ICMP_ULE: - assert(!isa<ConstantInt>(Op1) && "ICMP_ULE with ConstantInt not folded!"); - if (Op0Max.ule(Op1Min)) // A <=u B -> true if max(A) <= min(B) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Min.ugt(Op1Max)) // A <=u B -> false if min(A) > max(B) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - break; - } + if (Instruction *Res = foldICmpWithConstant(I)) + return Res; - // Turn a signed comparison into an unsigned one if both operands - // are known to have the same sign. - if (I.isSigned() && - ((Op0KnownZero.isNegative() && Op1KnownZero.isNegative()) || - (Op0KnownOne.isNegative() && Op1KnownOne.isNegative()))) - return new ICmpInst(I.getUnsignedPredicate(), Op0, Op1); - } + if (Instruction *Res = foldICmpUsingKnownBits(I)) + return Res; // Test if the ICmpInst instruction is used exclusively by a select as // part of a minimum or maximum operation. If so, refrain from doing @@ -3614,122 +4295,18 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { (SI->getOperand(2) == Op0 && SI->getOperand(1) == Op1)) return nullptr; - // See if we are doing a comparison between a constant and an instruction that - // can be folded into the comparison. - if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { - Value *A = nullptr, *B = nullptr; - // Since the RHS is a ConstantInt (CI), if the left hand side is an - // instruction, see if that instruction also has constants so that the - // instruction can be folded into the icmp - if (Instruction *LHSI = dyn_cast<Instruction>(Op0)) - if (Instruction *Res = visitICmpInstWithInstAndIntCst(I, LHSI, CI)) - return Res; - - // (icmp eq/ne (udiv A, B), 0) -> (icmp ugt/ule i32 B, A) - if (I.isEquality() && CI->isZero() && - match(Op0, m_UDiv(m_Value(A), m_Value(B)))) { - ICmpInst::Predicate Pred = I.getPredicate() == ICmpInst::ICMP_EQ - ? ICmpInst::ICMP_UGT - : ICmpInst::ICMP_ULE; - return new ICmpInst(Pred, B, A); - } - } - - // Handle icmp with constant (but not simple integer constant) RHS - if (Constant *RHSC = dyn_cast<Constant>(Op1)) { - if (Instruction *LHSI = dyn_cast<Instruction>(Op0)) - switch (LHSI->getOpcode()) { - case Instruction::GetElementPtr: - // icmp pred GEP (P, int 0, int 0, int 0), null -> icmp pred P, null - if (RHSC->isNullValue() && - cast<GetElementPtrInst>(LHSI)->hasAllZeroIndices()) - return new ICmpInst(I.getPredicate(), LHSI->getOperand(0), - Constant::getNullValue(LHSI->getOperand(0)->getType())); - break; - case Instruction::PHI: - // Only fold icmp into the PHI if the phi and icmp are in the same - // block. If in the same block, we're encouraging jump threading. If - // not, we are just pessimizing the code by making an i1 phi. - if (LHSI->getParent() == I.getParent()) - if (Instruction *NV = FoldOpIntoPhi(I)) - return NV; - break; - case Instruction::Select: { - // If either operand of the select is a constant, we can fold the - // comparison into the select arms, which will cause one to be - // constant folded and the select turned into a bitwise or. - Value *Op1 = nullptr, *Op2 = nullptr; - ConstantInt *CI = nullptr; - if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(1))) { - Op1 = ConstantExpr::getICmp(I.getPredicate(), C, RHSC); - CI = dyn_cast<ConstantInt>(Op1); - } - if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(2))) { - Op2 = ConstantExpr::getICmp(I.getPredicate(), C, RHSC); - CI = dyn_cast<ConstantInt>(Op2); - } - - // We only want to perform this transformation if it will not lead to - // additional code. This is true if either both sides of the select - // fold to a constant (in which case the icmp is replaced with a select - // which will usually simplify) or this is the only user of the - // select (in which case we are trading a select+icmp for a simpler - // select+icmp) or all uses of the select can be replaced based on - // dominance information ("Global cases"). - bool Transform = false; - if (Op1 && Op2) - Transform = true; - else if (Op1 || Op2) { - // Local case - if (LHSI->hasOneUse()) - Transform = true; - // Global cases - else if (CI && !CI->isZero()) - // When Op1 is constant try replacing select with second operand. - // Otherwise Op2 is constant and try replacing select with first - // operand. - Transform = replacedSelectWithOperand(cast<SelectInst>(LHSI), &I, - Op1 ? 2 : 1); - } - if (Transform) { - if (!Op1) - Op1 = Builder->CreateICmp(I.getPredicate(), LHSI->getOperand(1), - RHSC, I.getName()); - if (!Op2) - Op2 = Builder->CreateICmp(I.getPredicate(), LHSI->getOperand(2), - RHSC, I.getName()); - return SelectInst::Create(LHSI->getOperand(0), Op1, Op2); - } - break; - } - case Instruction::IntToPtr: - // icmp pred inttoptr(X), null -> icmp pred X, 0 - if (RHSC->isNullValue() && - DL.getIntPtrType(RHSC->getType()) == LHSI->getOperand(0)->getType()) - return new ICmpInst(I.getPredicate(), LHSI->getOperand(0), - Constant::getNullValue(LHSI->getOperand(0)->getType())); - break; + if (Instruction *Res = foldICmpInstWithConstant(I)) + return Res; - case Instruction::Load: - // Try to optimize things like "A[i] > 4" to index computations. - if (GetElementPtrInst *GEP = - dyn_cast<GetElementPtrInst>(LHSI->getOperand(0))) { - if (GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0))) - if (GV->isConstant() && GV->hasDefinitiveInitializer() && - !cast<LoadInst>(LHSI)->isVolatile()) - if (Instruction *Res = FoldCmpLoadFromIndexedGlobal(GEP, GV, I)) - return Res; - } - break; - } - } + if (Instruction *Res = foldICmpInstWithConstantNotInt(I)) + return Res; // If we can optimize a 'icmp GEP, P' or 'icmp P, GEP', do so now. if (GEPOperator *GEP = dyn_cast<GEPOperator>(Op0)) - if (Instruction *NI = FoldGEPICmp(GEP, Op1, I.getPredicate(), I)) + if (Instruction *NI = foldGEPICmp(GEP, Op1, I.getPredicate(), I)) return NI; if (GEPOperator *GEP = dyn_cast<GEPOperator>(Op1)) - if (Instruction *NI = FoldGEPICmp(GEP, Op0, + if (Instruction *NI = foldGEPICmp(GEP, Op0, ICmpInst::getSwappedPredicate(I.getPredicate()), I)) return NI; @@ -3737,10 +4314,10 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { if (Op0->getType()->isPointerTy() && I.isEquality()) { assert(Op1->getType()->isPointerTy() && "Comparing pointer with non-pointer?"); if (auto *Alloca = dyn_cast<AllocaInst>(GetUnderlyingObject(Op0, DL))) - if (Instruction *New = FoldAllocaCmp(I, Alloca, Op1)) + if (Instruction *New = foldAllocaCmp(I, Alloca, Op1)) return New; if (auto *Alloca = dyn_cast<AllocaInst>(GetUnderlyingObject(Op1, DL))) - if (Instruction *New = FoldAllocaCmp(I, Alloca, Op0)) + if (Instruction *New = foldAllocaCmp(I, Alloca, Op0)) return New; } @@ -3780,318 +4357,24 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { // For generality, we handle any zero-extension of any operand comparison // with a constant or another cast from the same type. if (isa<Constant>(Op1) || isa<CastInst>(Op1)) - if (Instruction *R = visitICmpInstWithCastAndCast(I)) + if (Instruction *R = foldICmpWithCastAndCast(I)) return R; } - // Special logic for binary operators. - BinaryOperator *BO0 = dyn_cast<BinaryOperator>(Op0); - BinaryOperator *BO1 = dyn_cast<BinaryOperator>(Op1); - if (BO0 || BO1) { - CmpInst::Predicate Pred = I.getPredicate(); - bool NoOp0WrapProblem = false, NoOp1WrapProblem = false; - if (BO0 && isa<OverflowingBinaryOperator>(BO0)) - NoOp0WrapProblem = ICmpInst::isEquality(Pred) || - (CmpInst::isUnsigned(Pred) && BO0->hasNoUnsignedWrap()) || - (CmpInst::isSigned(Pred) && BO0->hasNoSignedWrap()); - if (BO1 && isa<OverflowingBinaryOperator>(BO1)) - NoOp1WrapProblem = ICmpInst::isEquality(Pred) || - (CmpInst::isUnsigned(Pred) && BO1->hasNoUnsignedWrap()) || - (CmpInst::isSigned(Pred) && BO1->hasNoSignedWrap()); - - // Analyze the case when either Op0 or Op1 is an add instruction. - // Op0 = A + B (or A and B are null); Op1 = C + D (or C and D are null). - Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr; - if (BO0 && BO0->getOpcode() == Instruction::Add) { - A = BO0->getOperand(0); - B = BO0->getOperand(1); - } - if (BO1 && BO1->getOpcode() == Instruction::Add) { - C = BO1->getOperand(0); - D = BO1->getOperand(1); - } - - // icmp (X+cst) < 0 --> X < -cst - if (NoOp0WrapProblem && ICmpInst::isSigned(Pred) && match(Op1, m_Zero())) - if (ConstantInt *RHSC = dyn_cast_or_null<ConstantInt>(B)) - if (!RHSC->isMinValue(/*isSigned=*/true)) - return new ICmpInst(Pred, A, ConstantExpr::getNeg(RHSC)); - - // icmp (X+Y), X -> icmp Y, 0 for equalities or if there is no overflow. - if ((A == Op1 || B == Op1) && NoOp0WrapProblem) - return new ICmpInst(Pred, A == Op1 ? B : A, - Constant::getNullValue(Op1->getType())); - - // icmp X, (X+Y) -> icmp 0, Y for equalities or if there is no overflow. - if ((C == Op0 || D == Op0) && NoOp1WrapProblem) - return new ICmpInst(Pred, Constant::getNullValue(Op0->getType()), - C == Op0 ? D : C); - - // icmp (X+Y), (X+Z) -> icmp Y, Z for equalities or if there is no overflow. - if (A && C && (A == C || A == D || B == C || B == D) && - NoOp0WrapProblem && NoOp1WrapProblem && - // Try not to increase register pressure. - BO0->hasOneUse() && BO1->hasOneUse()) { - // Determine Y and Z in the form icmp (X+Y), (X+Z). - Value *Y, *Z; - if (A == C) { - // C + B == C + D -> B == D - Y = B; - Z = D; - } else if (A == D) { - // D + B == C + D -> B == C - Y = B; - Z = C; - } else if (B == C) { - // A + C == C + D -> A == D - Y = A; - Z = D; - } else { - assert(B == D); - // A + D == C + D -> A == C - Y = A; - Z = C; - } - return new ICmpInst(Pred, Y, Z); - } - - // icmp slt (X + -1), Y -> icmp sle X, Y - if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SLT && - match(B, m_AllOnes())) - return new ICmpInst(CmpInst::ICMP_SLE, A, Op1); - - // icmp sge (X + -1), Y -> icmp sgt X, Y - if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SGE && - match(B, m_AllOnes())) - return new ICmpInst(CmpInst::ICMP_SGT, A, Op1); - - // icmp sle (X + 1), Y -> icmp slt X, Y - if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SLE && - match(B, m_One())) - return new ICmpInst(CmpInst::ICMP_SLT, A, Op1); - - // icmp sgt (X + 1), Y -> icmp sge X, Y - if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SGT && - match(B, m_One())) - return new ICmpInst(CmpInst::ICMP_SGE, A, Op1); - - // icmp sgt X, (Y + -1) -> icmp sge X, Y - if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SGT && - match(D, m_AllOnes())) - return new ICmpInst(CmpInst::ICMP_SGE, Op0, C); - - // icmp sle X, (Y + -1) -> icmp slt X, Y - if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SLE && - match(D, m_AllOnes())) - return new ICmpInst(CmpInst::ICMP_SLT, Op0, C); - - // icmp sge X, (Y + 1) -> icmp sgt X, Y - if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SGE && - match(D, m_One())) - return new ICmpInst(CmpInst::ICMP_SGT, Op0, C); - - // icmp slt X, (Y + 1) -> icmp sle X, Y - if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SLT && - match(D, m_One())) - return new ICmpInst(CmpInst::ICMP_SLE, Op0, C); - - // if C1 has greater magnitude than C2: - // icmp (X + C1), (Y + C2) -> icmp (X + C3), Y - // s.t. C3 = C1 - C2 - // - // if C2 has greater magnitude than C1: - // icmp (X + C1), (Y + C2) -> icmp X, (Y + C3) - // s.t. C3 = C2 - C1 - if (A && C && NoOp0WrapProblem && NoOp1WrapProblem && - (BO0->hasOneUse() || BO1->hasOneUse()) && !I.isUnsigned()) - if (ConstantInt *C1 = dyn_cast<ConstantInt>(B)) - if (ConstantInt *C2 = dyn_cast<ConstantInt>(D)) { - const APInt &AP1 = C1->getValue(); - const APInt &AP2 = C2->getValue(); - if (AP1.isNegative() == AP2.isNegative()) { - APInt AP1Abs = C1->getValue().abs(); - APInt AP2Abs = C2->getValue().abs(); - if (AP1Abs.uge(AP2Abs)) { - ConstantInt *C3 = Builder->getInt(AP1 - AP2); - Value *NewAdd = Builder->CreateNSWAdd(A, C3); - return new ICmpInst(Pred, NewAdd, C); - } else { - ConstantInt *C3 = Builder->getInt(AP2 - AP1); - Value *NewAdd = Builder->CreateNSWAdd(C, C3); - return new ICmpInst(Pred, A, NewAdd); - } - } - } - - - // Analyze the case when either Op0 or Op1 is a sub instruction. - // Op0 = A - B (or A and B are null); Op1 = C - D (or C and D are null). - A = nullptr; - B = nullptr; - C = nullptr; - D = nullptr; - if (BO0 && BO0->getOpcode() == Instruction::Sub) { - A = BO0->getOperand(0); - B = BO0->getOperand(1); - } - if (BO1 && BO1->getOpcode() == Instruction::Sub) { - C = BO1->getOperand(0); - D = BO1->getOperand(1); - } - - // icmp (X-Y), X -> icmp 0, Y for equalities or if there is no overflow. - if (A == Op1 && NoOp0WrapProblem) - return new ICmpInst(Pred, Constant::getNullValue(Op1->getType()), B); - - // icmp X, (X-Y) -> icmp Y, 0 for equalities or if there is no overflow. - if (C == Op0 && NoOp1WrapProblem) - return new ICmpInst(Pred, D, Constant::getNullValue(Op0->getType())); - - // icmp (Y-X), (Z-X) -> icmp Y, Z for equalities or if there is no overflow. - if (B && D && B == D && NoOp0WrapProblem && NoOp1WrapProblem && - // Try not to increase register pressure. - BO0->hasOneUse() && BO1->hasOneUse()) - return new ICmpInst(Pred, A, C); - - // icmp (X-Y), (X-Z) -> icmp Z, Y for equalities or if there is no overflow. - if (A && C && A == C && NoOp0WrapProblem && NoOp1WrapProblem && - // Try not to increase register pressure. - BO0->hasOneUse() && BO1->hasOneUse()) - return new ICmpInst(Pred, D, B); - - // icmp (0-X) < cst --> x > -cst - if (NoOp0WrapProblem && ICmpInst::isSigned(Pred)) { - Value *X; - if (match(BO0, m_Neg(m_Value(X)))) - if (ConstantInt *RHSC = dyn_cast<ConstantInt>(Op1)) - if (!RHSC->isMinValue(/*isSigned=*/true)) - return new ICmpInst(I.getSwappedPredicate(), X, - ConstantExpr::getNeg(RHSC)); - } - - BinaryOperator *SRem = nullptr; - // icmp (srem X, Y), Y - if (BO0 && BO0->getOpcode() == Instruction::SRem && - Op1 == BO0->getOperand(1)) - SRem = BO0; - // icmp Y, (srem X, Y) - else if (BO1 && BO1->getOpcode() == Instruction::SRem && - Op0 == BO1->getOperand(1)) - SRem = BO1; - if (SRem) { - // We don't check hasOneUse to avoid increasing register pressure because - // the value we use is the same value this instruction was already using. - switch (SRem == BO0 ? ICmpInst::getSwappedPredicate(Pred) : Pred) { - default: break; - case ICmpInst::ICMP_EQ: - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - case ICmpInst::ICMP_NE: - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - case ICmpInst::ICMP_SGT: - case ICmpInst::ICMP_SGE: - return new ICmpInst(ICmpInst::ICMP_SGT, SRem->getOperand(1), - Constant::getAllOnesValue(SRem->getType())); - case ICmpInst::ICMP_SLT: - case ICmpInst::ICMP_SLE: - return new ICmpInst(ICmpInst::ICMP_SLT, SRem->getOperand(1), - Constant::getNullValue(SRem->getType())); - } - } - - if (BO0 && BO1 && BO0->getOpcode() == BO1->getOpcode() && - BO0->hasOneUse() && BO1->hasOneUse() && - BO0->getOperand(1) == BO1->getOperand(1)) { - switch (BO0->getOpcode()) { - default: break; - case Instruction::Add: - case Instruction::Sub: - case Instruction::Xor: - if (I.isEquality()) // a+x icmp eq/ne b+x --> a icmp b - return new ICmpInst(I.getPredicate(), BO0->getOperand(0), - BO1->getOperand(0)); - // icmp u/s (a ^ signbit), (b ^ signbit) --> icmp s/u a, b - if (ConstantInt *CI = dyn_cast<ConstantInt>(BO0->getOperand(1))) { - if (CI->getValue().isSignBit()) { - ICmpInst::Predicate Pred = I.isSigned() - ? I.getUnsignedPredicate() - : I.getSignedPredicate(); - return new ICmpInst(Pred, BO0->getOperand(0), - BO1->getOperand(0)); - } - - if (BO0->getOpcode() == Instruction::Xor && CI->isMaxValue(true)) { - ICmpInst::Predicate Pred = I.isSigned() - ? I.getUnsignedPredicate() - : I.getSignedPredicate(); - Pred = I.getSwappedPredicate(Pred); - return new ICmpInst(Pred, BO0->getOperand(0), - BO1->getOperand(0)); - } - } - break; - case Instruction::Mul: - if (!I.isEquality()) - break; - - if (ConstantInt *CI = dyn_cast<ConstantInt>(BO0->getOperand(1))) { - // a * Cst icmp eq/ne b * Cst --> a & Mask icmp b & Mask - // Mask = -1 >> count-trailing-zeros(Cst). - if (!CI->isZero() && !CI->isOne()) { - const APInt &AP = CI->getValue(); - ConstantInt *Mask = ConstantInt::get(I.getContext(), - APInt::getLowBitsSet(AP.getBitWidth(), - AP.getBitWidth() - - AP.countTrailingZeros())); - Value *And1 = Builder->CreateAnd(BO0->getOperand(0), Mask); - Value *And2 = Builder->CreateAnd(BO1->getOperand(0), Mask); - return new ICmpInst(I.getPredicate(), And1, And2); - } - } - break; - case Instruction::UDiv: - case Instruction::LShr: - if (I.isSigned()) - break; - // fall-through - case Instruction::SDiv: - case Instruction::AShr: - if (!BO0->isExact() || !BO1->isExact()) - break; - return new ICmpInst(I.getPredicate(), BO0->getOperand(0), - BO1->getOperand(0)); - case Instruction::Shl: { - bool NUW = BO0->hasNoUnsignedWrap() && BO1->hasNoUnsignedWrap(); - bool NSW = BO0->hasNoSignedWrap() && BO1->hasNoSignedWrap(); - if (!NUW && !NSW) - break; - if (!NSW && I.isSigned()) - break; - return new ICmpInst(I.getPredicate(), BO0->getOperand(0), - BO1->getOperand(0)); - } - } - } - - if (BO0) { - // Transform A & (L - 1) `ult` L --> L != 0 - auto LSubOne = m_Add(m_Specific(Op1), m_AllOnes()); - auto BitwiseAnd = - m_CombineOr(m_And(m_Value(), LSubOne), m_And(LSubOne, m_Value())); + if (Instruction *Res = foldICmpBinOp(I)) + return Res; - if (match(BO0, BitwiseAnd) && I.getPredicate() == ICmpInst::ICMP_ULT) { - auto *Zero = Constant::getNullValue(BO0->getType()); - return new ICmpInst(ICmpInst::ICMP_NE, Op1, Zero); - } - } - } + if (Instruction *Res = foldICmpWithMinMax(I)) + return Res; - { Value *A, *B; + { + Value *A, *B; // Transform (A & ~B) == 0 --> (A & B) != 0 // and (A & ~B) != 0 --> (A & B) == 0 // if A is a power of 2. if (match(Op0, m_And(m_Value(A), m_Not(m_Value(B)))) && match(Op1, m_Zero()) && - isKnownToBeAPowerOfTwo(A, DL, false, 0, AC, &I, DT) && I.isEquality()) + isKnownToBeAPowerOfTwo(A, DL, false, 0, &AC, &I, &DT) && I.isEquality()) return new ICmpInst(I.getInversePredicate(), Builder->CreateAnd(A, B), Op1); @@ -4120,149 +4403,17 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { // (zext a) * (zext b) --> llvm.umul.with.overflow. if (match(Op0, m_Mul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) { - if (Instruction *R = ProcessUMulZExtIdiom(I, Op0, Op1, *this)) + if (Instruction *R = processUMulZExtIdiom(I, Op0, Op1, *this)) return R; } if (match(Op1, m_Mul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) { - if (Instruction *R = ProcessUMulZExtIdiom(I, Op1, Op0, *this)) + if (Instruction *R = processUMulZExtIdiom(I, Op1, Op0, *this)) return R; } } - if (I.isEquality()) { - Value *A, *B, *C, *D; - - if (match(Op0, m_Xor(m_Value(A), m_Value(B)))) { - if (A == Op1 || B == Op1) { // (A^B) == A -> B == 0 - Value *OtherVal = A == Op1 ? B : A; - return new ICmpInst(I.getPredicate(), OtherVal, - Constant::getNullValue(A->getType())); - } - - if (match(Op1, m_Xor(m_Value(C), m_Value(D)))) { - // A^c1 == C^c2 --> A == C^(c1^c2) - ConstantInt *C1, *C2; - if (match(B, m_ConstantInt(C1)) && - match(D, m_ConstantInt(C2)) && Op1->hasOneUse()) { - Constant *NC = Builder->getInt(C1->getValue() ^ C2->getValue()); - Value *Xor = Builder->CreateXor(C, NC); - return new ICmpInst(I.getPredicate(), A, Xor); - } - - // A^B == A^D -> B == D - if (A == C) return new ICmpInst(I.getPredicate(), B, D); - if (A == D) return new ICmpInst(I.getPredicate(), B, C); - if (B == C) return new ICmpInst(I.getPredicate(), A, D); - if (B == D) return new ICmpInst(I.getPredicate(), A, C); - } - } - - if (match(Op1, m_Xor(m_Value(A), m_Value(B))) && - (A == Op0 || B == Op0)) { - // A == (A^B) -> B == 0 - Value *OtherVal = A == Op0 ? B : A; - return new ICmpInst(I.getPredicate(), OtherVal, - Constant::getNullValue(A->getType())); - } - - // (X&Z) == (Y&Z) -> (X^Y) & Z == 0 - if (match(Op0, m_OneUse(m_And(m_Value(A), m_Value(B)))) && - match(Op1, m_OneUse(m_And(m_Value(C), m_Value(D))))) { - Value *X = nullptr, *Y = nullptr, *Z = nullptr; - - if (A == C) { - X = B; Y = D; Z = A; - } else if (A == D) { - X = B; Y = C; Z = A; - } else if (B == C) { - X = A; Y = D; Z = B; - } else if (B == D) { - X = A; Y = C; Z = B; - } - - if (X) { // Build (X^Y) & Z - Op1 = Builder->CreateXor(X, Y); - Op1 = Builder->CreateAnd(Op1, Z); - I.setOperand(0, Op1); - I.setOperand(1, Constant::getNullValue(Op1->getType())); - return &I; - } - } - - // Transform (zext A) == (B & (1<<X)-1) --> A == (trunc B) - // and (B & (1<<X)-1) == (zext A) --> A == (trunc B) - ConstantInt *Cst1; - if ((Op0->hasOneUse() && - match(Op0, m_ZExt(m_Value(A))) && - match(Op1, m_And(m_Value(B), m_ConstantInt(Cst1)))) || - (Op1->hasOneUse() && - match(Op0, m_And(m_Value(B), m_ConstantInt(Cst1))) && - match(Op1, m_ZExt(m_Value(A))))) { - APInt Pow2 = Cst1->getValue() + 1; - if (Pow2.isPowerOf2() && isa<IntegerType>(A->getType()) && - Pow2.logBase2() == cast<IntegerType>(A->getType())->getBitWidth()) - return new ICmpInst(I.getPredicate(), A, - Builder->CreateTrunc(B, A->getType())); - } - - // (A >> C) == (B >> C) --> (A^B) u< (1 << C) - // For lshr and ashr pairs. - if ((match(Op0, m_OneUse(m_LShr(m_Value(A), m_ConstantInt(Cst1)))) && - match(Op1, m_OneUse(m_LShr(m_Value(B), m_Specific(Cst1))))) || - (match(Op0, m_OneUse(m_AShr(m_Value(A), m_ConstantInt(Cst1)))) && - match(Op1, m_OneUse(m_AShr(m_Value(B), m_Specific(Cst1)))))) { - unsigned TypeBits = Cst1->getBitWidth(); - unsigned ShAmt = (unsigned)Cst1->getLimitedValue(TypeBits); - if (ShAmt < TypeBits && ShAmt != 0) { - ICmpInst::Predicate Pred = I.getPredicate() == ICmpInst::ICMP_NE - ? ICmpInst::ICMP_UGE - : ICmpInst::ICMP_ULT; - Value *Xor = Builder->CreateXor(A, B, I.getName() + ".unshifted"); - APInt CmpVal = APInt::getOneBitSet(TypeBits, ShAmt); - return new ICmpInst(Pred, Xor, Builder->getInt(CmpVal)); - } - } - - // (A << C) == (B << C) --> ((A^B) & (~0U >> C)) == 0 - if (match(Op0, m_OneUse(m_Shl(m_Value(A), m_ConstantInt(Cst1)))) && - match(Op1, m_OneUse(m_Shl(m_Value(B), m_Specific(Cst1))))) { - unsigned TypeBits = Cst1->getBitWidth(); - unsigned ShAmt = (unsigned)Cst1->getLimitedValue(TypeBits); - if (ShAmt < TypeBits && ShAmt != 0) { - Value *Xor = Builder->CreateXor(A, B, I.getName() + ".unshifted"); - APInt AndVal = APInt::getLowBitsSet(TypeBits, TypeBits - ShAmt); - Value *And = Builder->CreateAnd(Xor, Builder->getInt(AndVal), - I.getName() + ".mask"); - return new ICmpInst(I.getPredicate(), And, - Constant::getNullValue(Cst1->getType())); - } - } - - // Transform "icmp eq (trunc (lshr(X, cst1)), cst" to - // "icmp (and X, mask), cst" - uint64_t ShAmt = 0; - if (Op0->hasOneUse() && - match(Op0, m_Trunc(m_OneUse(m_LShr(m_Value(A), - m_ConstantInt(ShAmt))))) && - match(Op1, m_ConstantInt(Cst1)) && - // Only do this when A has multiple uses. This is most important to do - // when it exposes other optimizations. - !A->hasOneUse()) { - unsigned ASize =cast<IntegerType>(A->getType())->getPrimitiveSizeInBits(); - - if (ShAmt < ASize) { - APInt MaskV = - APInt::getLowBitsSet(ASize, Op0->getType()->getPrimitiveSizeInBits()); - MaskV <<= ShAmt; - - APInt CmpV = Cst1->getValue().zext(ASize); - CmpV <<= ShAmt; - - Value *Mask = Builder->CreateAnd(A, Builder->getInt(MaskV)); - return new ICmpInst(I.getPredicate(), Mask, Builder->getInt(CmpV)); - } - } - } + if (Instruction *Res = foldICmpEquality(I)) + return Res; // The 'cmpxchg' instruction returns an aggregate containing the old value and // an i1 which indicates whether or not we successfully did the swap. @@ -4284,18 +4435,17 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { Value *X; ConstantInt *Cst; // icmp X+Cst, X if (match(Op0, m_Add(m_Value(X), m_ConstantInt(Cst))) && Op1 == X) - return FoldICmpAddOpCst(I, X, Cst, I.getPredicate()); + return foldICmpAddOpConst(I, X, Cst, I.getPredicate()); // icmp X, X+Cst if (match(Op1, m_Add(m_Value(X), m_ConstantInt(Cst))) && Op0 == X) - return FoldICmpAddOpCst(I, X, Cst, I.getSwappedPredicate()); + return foldICmpAddOpConst(I, X, Cst, I.getSwappedPredicate()); } return Changed ? &I : nullptr; } /// Fold fcmp ([us]itofp x, cst) if possible. -Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I, - Instruction *LHSI, +Instruction *InstCombiner::foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI, Constant *RHSC) { if (!isa<ConstantFP>(RHSC)) return nullptr; const APFloat &RHS = cast<ConstantFP>(RHSC)->getValueAPF(); @@ -4339,21 +4489,21 @@ Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I, // This would allow us to handle (fptosi (x >>s 62) to float) if x is i64 f.e. unsigned InputSize = IntTy->getScalarSizeInBits(); - // Following test does NOT adjust InputSize downwards for signed inputs, - // because the most negative value still requires all the mantissa bits + // Following test does NOT adjust InputSize downwards for signed inputs, + // because the most negative value still requires all the mantissa bits // to distinguish it from one less than that value. if ((int)InputSize > MantissaWidth) { // Conversion would lose accuracy. Check if loss can impact comparison. int Exp = ilogb(RHS); if (Exp == APFloat::IEK_Inf) { int MaxExponent = ilogb(APFloat::getLargest(RHS.getSemantics())); - if (MaxExponent < (int)InputSize - !LHSUnsigned) + if (MaxExponent < (int)InputSize - !LHSUnsigned) // Conversion could create infinity. return nullptr; } else { - // Note that if RHS is zero or NaN, then Exp is negative + // Note that if RHS is zero or NaN, then Exp is negative // and first condition is trivially false. - if (MantissaWidth <= Exp && Exp <= (int)InputSize - !LHSUnsigned) + if (MantissaWidth <= Exp && Exp <= (int)InputSize - !LHSUnsigned) // Conversion could affect comparison. return nullptr; } @@ -4547,7 +4697,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = SimplifyFCmpInst(I.getPredicate(), Op0, Op1, - I.getFastMathFlags(), DL, TLI, DT, AC, &I)) + I.getFastMathFlags(), DL, &TLI, &DT, &AC, &I)) return replaceInstUsesWith(I, V); // Simplify 'fcmp pred X, X' @@ -4601,17 +4751,17 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { const fltSemantics *Sem; // FIXME: This shouldn't be here. if (LHSExt->getSrcTy()->isHalfTy()) - Sem = &APFloat::IEEEhalf; + Sem = &APFloat::IEEEhalf(); else if (LHSExt->getSrcTy()->isFloatTy()) - Sem = &APFloat::IEEEsingle; + Sem = &APFloat::IEEEsingle(); else if (LHSExt->getSrcTy()->isDoubleTy()) - Sem = &APFloat::IEEEdouble; + Sem = &APFloat::IEEEdouble(); else if (LHSExt->getSrcTy()->isFP128Ty()) - Sem = &APFloat::IEEEquad; + Sem = &APFloat::IEEEquad(); else if (LHSExt->getSrcTy()->isX86_FP80Ty()) - Sem = &APFloat::x87DoubleExtended; + Sem = &APFloat::x87DoubleExtended(); else if (LHSExt->getSrcTy()->isPPC_FP128Ty()) - Sem = &APFloat::PPCDoubleDouble; + Sem = &APFloat::PPCDoubleDouble(); else break; @@ -4641,7 +4791,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { break; case Instruction::SIToFP: case Instruction::UIToFP: - if (Instruction *NV = FoldFCmp_IntToFP_Cst(I, LHSI, RHSC)) + if (Instruction *NV = foldFCmpIntToFPConst(I, LHSI, RHSC)) return NV; break; case Instruction::FSub: { @@ -4658,7 +4808,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { if (GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0))) if (GV->isConstant() && GV->hasDefinitiveInitializer() && !cast<LoadInst>(LHSI)->isVolatile()) - if (Instruction *Res = FoldCmpLoadFromIndexedGlobal(GEP, GV, I)) + if (Instruction *Res = foldCmpLoadFromIndexedGlobal(GEP, GV, I)) return Res; } break; @@ -4667,7 +4817,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { break; CallInst *CI = cast<CallInst>(LHSI); - Intrinsic::ID IID = getIntrinsicForCallSite(CI, TLI); + Intrinsic::ID IID = getIntrinsicForCallSite(CI, &TLI); if (IID != Intrinsic::fabs) break; diff --git a/lib/Transforms/InstCombine/InstCombineInternal.h b/lib/Transforms/InstCombine/InstCombineInternal.h index aa421ff594fb..3cefe715e567 100644 --- a/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/lib/Transforms/InstCombine/InstCombineInternal.h @@ -84,6 +84,24 @@ static inline bool IsFreeToInvert(Value *V, bool WillInvertAllUses) { if (isa<ConstantInt>(V)) return true; + // A vector of constant integers can be inverted easily. + Constant *CV; + if (V->getType()->isVectorTy() && match(V, PatternMatch::m_Constant(CV))) { + unsigned NumElts = V->getType()->getVectorNumElements(); + for (unsigned i = 0; i != NumElts; ++i) { + Constant *Elt = CV->getAggregateElement(i); + if (!Elt) + return false; + + if (isa<UndefValue>(Elt)) + continue; + + if (!isa<ConstantInt>(Elt)) + return false; + } + return true; + } + // Compares can be inverted if all of their uses are being modified to use the // ~V. if (isa<CmpInst>(V)) @@ -135,33 +153,10 @@ IntrinsicIDToOverflowCheckFlavor(unsigned ID) { } } -/// \brief An IRBuilder inserter that adds new instructions to the instcombine -/// worklist. -class LLVM_LIBRARY_VISIBILITY InstCombineIRInserter - : public IRBuilderDefaultInserter { - InstCombineWorklist &Worklist; - AssumptionCache *AC; - -public: - InstCombineIRInserter(InstCombineWorklist &WL, AssumptionCache *AC) - : Worklist(WL), AC(AC) {} - - void InsertHelper(Instruction *I, const Twine &Name, BasicBlock *BB, - BasicBlock::iterator InsertPt) const { - IRBuilderDefaultInserter::InsertHelper(I, Name, BB, InsertPt); - Worklist.Add(I); - - using namespace llvm::PatternMatch; - if (match(I, m_Intrinsic<Intrinsic::assume>())) - AC->registerAssumption(cast<CallInst>(I)); - } -}; - /// \brief The core instruction combiner logic. /// /// This class provides both the logic to recursively visit instructions and -/// combine them, as well as the pass infrastructure for running this as part -/// of the LLVM pass pipeline. +/// combine them. class LLVM_LIBRARY_VISIBILITY InstCombiner : public InstVisitor<InstCombiner, Instruction *> { // FIXME: These members shouldn't be public. @@ -171,7 +166,7 @@ public: /// \brief An IRBuilder that automatically inserts new instructions into the /// worklist. - typedef IRBuilder<TargetFolder, InstCombineIRInserter> BuilderTy; + typedef IRBuilder<TargetFolder, IRBuilderCallbackInserter> BuilderTy; BuilderTy *Builder; private: @@ -183,10 +178,9 @@ private: AliasAnalysis *AA; // Required analyses. - // FIXME: These can never be null and should be references. - AssumptionCache *AC; - TargetLibraryInfo *TLI; - DominatorTree *DT; + AssumptionCache &AC; + TargetLibraryInfo &TLI; + DominatorTree &DT; const DataLayout &DL; // Optional analyses. When non-null, these can both be used to do better @@ -198,8 +192,8 @@ private: public: InstCombiner(InstCombineWorklist &Worklist, BuilderTy *Builder, bool MinimizeSize, bool ExpensiveCombines, AliasAnalysis *AA, - AssumptionCache *AC, TargetLibraryInfo *TLI, - DominatorTree *DT, const DataLayout &DL, LoopInfo *LI) + AssumptionCache &AC, TargetLibraryInfo &TLI, + DominatorTree &DT, const DataLayout &DL, LoopInfo *LI) : Worklist(Worklist), Builder(Builder), MinimizeSize(MinimizeSize), ExpensiveCombines(ExpensiveCombines), AA(AA), AC(AC), TLI(TLI), DT(DT), DL(DL), LI(LI), MadeIRChange(false) {} @@ -209,15 +203,15 @@ public: /// \returns true if the IR is changed. bool run(); - AssumptionCache *getAssumptionCache() const { return AC; } + AssumptionCache &getAssumptionCache() const { return AC; } const DataLayout &getDataLayout() const { return DL; } - DominatorTree *getDominatorTree() const { return DT; } + DominatorTree &getDominatorTree() const { return DT; } LoopInfo *getLoopInfo() const { return LI; } - TargetLibraryInfo *getTargetLibraryInfo() const { return TLI; } + TargetLibraryInfo &getTargetLibraryInfo() const { return TLI; } // Visitation implementation - Implement instruction combining for different // instruction types. The semantics are as follows: @@ -262,29 +256,8 @@ public: Instruction *visitAShr(BinaryOperator &I); Instruction *visitLShr(BinaryOperator &I); Instruction *commonShiftTransforms(BinaryOperator &I); - Instruction *FoldFCmp_IntToFP_Cst(FCmpInst &I, Instruction *LHSI, - Constant *RHSC); - Instruction *FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, - GlobalVariable *GV, CmpInst &ICI, - ConstantInt *AndCst = nullptr); Instruction *visitFCmpInst(FCmpInst &I); Instruction *visitICmpInst(ICmpInst &I); - Instruction *visitICmpInstWithCastAndCast(ICmpInst &ICI); - Instruction *visitICmpInstWithInstAndIntCst(ICmpInst &ICI, Instruction *LHS, - ConstantInt *RHS); - Instruction *FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, - ConstantInt *DivRHS); - Instruction *FoldICmpShrCst(ICmpInst &ICI, BinaryOperator *DivI, - ConstantInt *DivRHS); - Instruction *FoldICmpCstShrCst(ICmpInst &I, Value *Op, Value *A, - ConstantInt *CI1, ConstantInt *CI2); - Instruction *FoldICmpCstShlCst(ICmpInst &I, Value *Op, Value *A, - ConstantInt *CI1, ConstantInt *CI2); - Instruction *FoldICmpAddOpCst(Instruction &ICI, Value *X, ConstantInt *CI, - ICmpInst::Predicate Pred); - Instruction *FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, - ICmpInst::Predicate Cond, Instruction &I); - Instruction *FoldAllocaCmp(ICmpInst &ICI, AllocaInst *Alloca, Value *Other); Instruction *FoldShiftByConstant(Value *Op0, Constant *Op1, BinaryOperator &I); Instruction *commonCastTransforms(CastInst &CI); @@ -302,14 +275,8 @@ public: Instruction *visitIntToPtr(IntToPtrInst &CI); Instruction *visitBitCast(BitCastInst &CI); Instruction *visitAddrSpaceCast(AddrSpaceCastInst &CI); - Instruction *FoldSelectOpOp(SelectInst &SI, Instruction *TI, Instruction *FI); - Instruction *FoldSelectIntoOp(SelectInst &SI, Value *, Value *); - Instruction *FoldSPFofSPF(Instruction *Inner, SelectPatternFlavor SPF1, - Value *A, Value *B, Instruction &Outer, - SelectPatternFlavor SPF2, Value *C); Instruction *FoldItoFPtoI(Instruction &FI); Instruction *visitSelectInst(SelectInst &SI); - Instruction *visitSelectInstWithICmp(SelectInst &SI, ICmpInst *ICI); Instruction *visitCallInst(CallInst &CI); Instruction *visitInvokeInst(InvokeInst &II); @@ -333,16 +300,16 @@ public: Instruction *visitVAStartInst(VAStartInst &I); Instruction *visitVACopyInst(VACopyInst &I); - // visitInstruction - Specify what to return for unhandled instructions... + /// Specify what to return for unhandled instructions. Instruction *visitInstruction(Instruction &I) { return nullptr; } - // True when DB dominates all uses of DI execpt UI. - // UI must be in the same block as DI. - // The routine checks that the DI parent and DB are different. + /// True when DB dominates all uses of DI except UI. + /// UI must be in the same block as DI. + /// The routine checks that the DI parent and DB are different. bool dominatesAllUses(const Instruction *DI, const Instruction *UI, const BasicBlock *DB) const; - // Replace select with select operand SIOpd in SI-ICmp sequence when possible + /// Try to replace select with select operand SIOpd in SI-ICmp sequence. bool replacedSelectWithOperand(SelectInst *SI, const ICmpInst *Icmp, const unsigned SIOpd); @@ -355,14 +322,16 @@ private: SmallVectorImpl<Value *> &NewIndices); Instruction *FoldOpIntoSelect(Instruction &Op, SelectInst *SI); - /// \brief Classify whether a cast is worth optimizing. + /// Classify whether a cast is worth optimizing. + /// + /// This is a helper to decide whether the simplification of + /// logic(cast(A), cast(B)) to cast(logic(A, B)) should be performed. /// - /// Returns true if the cast from "V to Ty" actually results in any code - /// being generated and is interesting to optimize out. If the cast can be - /// eliminated by some other simple transformation, we prefer to do the - /// simplification first. - bool ShouldOptimizeCast(Instruction::CastOps opcode, const Value *V, - Type *Ty); + /// \param CI The cast we are interested in. + /// + /// \return true if this cast actually results in any code being generated and + /// if it cannot already be eliminated by some other transformation. + bool shouldOptimizeCast(CastInst *CI); /// \brief Try to optimize a sequence of instructions checking if an operation /// on LHS and RHS overflows. @@ -385,8 +354,22 @@ private: bool transformConstExprCastCall(CallSite CS); Instruction *transformCallThroughTrampoline(CallSite CS, IntrinsicInst *Tramp); - Instruction *transformZExtICmp(ICmpInst *ICI, Instruction &CI, - bool DoXform = true); + + /// Transform (zext icmp) to bitwise / integer operations in order to + /// eliminate it. + /// + /// \param ICI The icmp of the (zext icmp) pair we are interested in. + /// \parem CI The zext of the (zext icmp) pair we are interested in. + /// \param DoTransform Pass false to just test whether the given (zext icmp) + /// would be transformed. Pass true to actually perform the transformation. + /// + /// \return null if the transformation cannot be performed. If the + /// transformation can be performed the new instruction that replaces the + /// (zext icmp) pair will be returned (if \p DoTransform is false the + /// unmodified \p ICI will be returned in this case). + Instruction *transformZExtICmp(ICmpInst *ICI, ZExtInst &CI, + bool DoTransform = true); + Instruction *transformSExtICmp(ICmpInst *ICI, Instruction &CI); bool WillNotOverflowSignedAdd(Value *LHS, Value *RHS, Instruction &CxtI); bool WillNotOverflowSignedSub(Value *LHS, Value *RHS, Instruction &CxtI); @@ -396,6 +379,21 @@ private: Instruction *scalarizePHI(ExtractElementInst &EI, PHINode *PN); Value *EvaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask); Instruction *foldCastedBitwiseLogic(BinaryOperator &I); + Instruction *shrinkBitwiseLogic(TruncInst &Trunc); + Instruction *optimizeBitCastFromPhi(CastInst &CI, PHINode *PN); + + /// Determine if a pair of casts can be replaced by a single cast. + /// + /// \param CI1 The first of a pair of casts. + /// \param CI2 The second of a pair of casts. + /// + /// \return 0 if the cast pair cannot be eliminated, otherwise returns an + /// Instruction::CastOps value for a cast that can replace the pair, casting + /// CI1->getSrcTy() to CI2->getDstTy(). + /// + /// \see CastInst::isEliminableCastPair + Instruction::CastOps isEliminableCastPair(const CastInst *CI1, + const CastInst *CI2); public: /// \brief Inserts an instruction \p New before instruction \p Old @@ -476,30 +474,30 @@ public: void computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, unsigned Depth, Instruction *CxtI) const { - return llvm::computeKnownBits(V, KnownZero, KnownOne, DL, Depth, AC, CxtI, - DT); + return llvm::computeKnownBits(V, KnownZero, KnownOne, DL, Depth, &AC, CxtI, + &DT); } bool MaskedValueIsZero(Value *V, const APInt &Mask, unsigned Depth = 0, Instruction *CxtI = nullptr) const { - return llvm::MaskedValueIsZero(V, Mask, DL, Depth, AC, CxtI, DT); + return llvm::MaskedValueIsZero(V, Mask, DL, Depth, &AC, CxtI, &DT); } unsigned ComputeNumSignBits(Value *Op, unsigned Depth = 0, Instruction *CxtI = nullptr) const { - return llvm::ComputeNumSignBits(Op, DL, Depth, AC, CxtI, DT); + return llvm::ComputeNumSignBits(Op, DL, Depth, &AC, CxtI, &DT); } void ComputeSignBit(Value *V, bool &KnownZero, bool &KnownOne, unsigned Depth = 0, Instruction *CxtI = nullptr) const { - return llvm::ComputeSignBit(V, KnownZero, KnownOne, DL, Depth, AC, CxtI, - DT); + return llvm::ComputeSignBit(V, KnownZero, KnownOne, DL, Depth, &AC, CxtI, + &DT); } OverflowResult computeOverflowForUnsignedMul(Value *LHS, Value *RHS, const Instruction *CxtI) { - return llvm::computeOverflowForUnsignedMul(LHS, RHS, DL, AC, CxtI, DT); + return llvm::computeOverflowForUnsignedMul(LHS, RHS, DL, &AC, CxtI, &DT); } OverflowResult computeOverflowForUnsignedAdd(Value *LHS, Value *RHS, const Instruction *CxtI) { - return llvm::computeOverflowForUnsignedAdd(LHS, RHS, DL, AC, CxtI, DT); + return llvm::computeOverflowForUnsignedAdd(LHS, RHS, DL, &AC, CxtI, &DT); } private: @@ -554,13 +552,82 @@ private: Instruction *FoldPHIArgLoadIntoPHI(PHINode &PN); Instruction *FoldPHIArgZextsIntoPHI(PHINode &PN); + /// Helper function for FoldPHIArgXIntoPHI() to get debug location for the + /// folded operation. + DebugLoc PHIArgMergedDebugLoc(PHINode &PN); + + Instruction *foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, + ICmpInst::Predicate Cond, Instruction &I); + Instruction *foldAllocaCmp(ICmpInst &ICI, const AllocaInst *Alloca, + const Value *Other); + Instruction *foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, + GlobalVariable *GV, CmpInst &ICI, + ConstantInt *AndCst = nullptr); + Instruction *foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI, + Constant *RHSC); + Instruction *foldICmpAddOpConst(Instruction &ICI, Value *X, ConstantInt *CI, + ICmpInst::Predicate Pred); + Instruction *foldICmpWithCastAndCast(ICmpInst &ICI); + + Instruction *foldICmpUsingKnownBits(ICmpInst &Cmp); + Instruction *foldICmpWithConstant(ICmpInst &Cmp); + Instruction *foldICmpInstWithConstant(ICmpInst &Cmp); + Instruction *foldICmpInstWithConstantNotInt(ICmpInst &Cmp); + Instruction *foldICmpBinOp(ICmpInst &Cmp); + Instruction *foldICmpEquality(ICmpInst &Cmp); + + Instruction *foldICmpTruncConstant(ICmpInst &Cmp, Instruction *Trunc, + const APInt *C); + Instruction *foldICmpAndConstant(ICmpInst &Cmp, BinaryOperator *And, + const APInt *C); + Instruction *foldICmpXorConstant(ICmpInst &Cmp, BinaryOperator *Xor, + const APInt *C); + Instruction *foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or, + const APInt *C); + Instruction *foldICmpMulConstant(ICmpInst &Cmp, BinaryOperator *Mul, + const APInt *C); + Instruction *foldICmpShlConstant(ICmpInst &Cmp, BinaryOperator *Shl, + const APInt *C); + Instruction *foldICmpShrConstant(ICmpInst &Cmp, BinaryOperator *Shr, + const APInt *C); + Instruction *foldICmpUDivConstant(ICmpInst &Cmp, BinaryOperator *UDiv, + const APInt *C); + Instruction *foldICmpDivConstant(ICmpInst &Cmp, BinaryOperator *Div, + const APInt *C); + Instruction *foldICmpSubConstant(ICmpInst &Cmp, BinaryOperator *Sub, + const APInt *C); + Instruction *foldICmpAddConstant(ICmpInst &Cmp, BinaryOperator *Add, + const APInt *C); + Instruction *foldICmpAndConstConst(ICmpInst &Cmp, BinaryOperator *And, + const APInt *C1); + Instruction *foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And, + const APInt *C1, const APInt *C2); + Instruction *foldICmpShrConstConst(ICmpInst &I, Value *ShAmt, const APInt &C1, + const APInt &C2); + Instruction *foldICmpShlConstConst(ICmpInst &I, Value *ShAmt, const APInt &C1, + const APInt &C2); + + Instruction *foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, + BinaryOperator *BO, + const APInt *C); + Instruction *foldICmpIntrinsicWithConstant(ICmpInst &ICI, const APInt *C); + + // Helpers of visitSelectInst(). + Instruction *foldSelectExtConst(SelectInst &Sel); + Instruction *foldSelectOpOp(SelectInst &SI, Instruction *TI, Instruction *FI); + Instruction *foldSelectIntoOp(SelectInst &SI, Value *, Value *); + Instruction *foldSPFofSPF(Instruction *Inner, SelectPatternFlavor SPF1, + Value *A, Value *B, Instruction &Outer, + SelectPatternFlavor SPF2, Value *C); + Instruction *foldSelectInstWithICmp(SelectInst &SI, ICmpInst *ICI); + Instruction *OptAndOp(Instruction *Op, ConstantInt *OpRHS, ConstantInt *AndRHS, BinaryOperator &TheAnd); Value *FoldLogicalPlusAnd(Value *LHS, Value *RHS, ConstantInt *Mask, bool isSub, Instruction &I); - Value *InsertRangeTest(Value *V, Constant *Lo, Constant *Hi, bool isSigned, - bool Inside); + Value *insertRangeTest(Value *V, const APInt &Lo, const APInt &Hi, + bool isSigned, bool Inside); Instruction *PromoteCastOfAllocation(BitCastInst &CI, AllocaInst &AI); Instruction *MatchBSwap(BinaryOperator &I); bool SimplifyStoreAtEndOfBlock(StoreInst &SI); diff --git a/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp index d88456ee4adc..5276bee4e0a2 100644 --- a/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ b/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -15,6 +15,7 @@ #include "llvm/ADT/SmallString.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/Loads.h" +#include "llvm/IR/ConstantRange.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/IntrinsicInst.h" @@ -59,14 +60,14 @@ isOnlyCopiedFromConstantGlobal(Value *V, MemTransferInst *&TheCopy, // eliminate the markers. SmallVector<std::pair<Value *, bool>, 35> ValuesToInspect; - ValuesToInspect.push_back(std::make_pair(V, false)); + ValuesToInspect.emplace_back(V, false); while (!ValuesToInspect.empty()) { auto ValuePair = ValuesToInspect.pop_back_val(); const bool IsOffset = ValuePair.second; for (auto &U : ValuePair.first->uses()) { - Instruction *I = cast<Instruction>(U.getUser()); + auto *I = cast<Instruction>(U.getUser()); - if (LoadInst *LI = dyn_cast<LoadInst>(I)) { + if (auto *LI = dyn_cast<LoadInst>(I)) { // Ignore non-volatile loads, they are always ok. if (!LI->isSimple()) return false; continue; @@ -74,14 +75,13 @@ isOnlyCopiedFromConstantGlobal(Value *V, MemTransferInst *&TheCopy, if (isa<BitCastInst>(I) || isa<AddrSpaceCastInst>(I)) { // If uses of the bitcast are ok, we are ok. - ValuesToInspect.push_back(std::make_pair(I, IsOffset)); + ValuesToInspect.emplace_back(I, IsOffset); continue; } - if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(I)) { + if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) { // If the GEP has all zero indices, it doesn't offset the pointer. If it // doesn't, it does. - ValuesToInspect.push_back( - std::make_pair(I, IsOffset || !GEP->hasAllZeroIndices())); + ValuesToInspect.emplace_back(I, IsOffset || !GEP->hasAllZeroIndices()); continue; } @@ -286,7 +286,7 @@ Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) { SmallVector<Instruction *, 4> ToDelete; if (MemTransferInst *Copy = isOnlyCopiedFromConstantGlobal(&AI, ToDelete)) { unsigned SourceAlign = getOrEnforceKnownAlignment( - Copy->getSource(), AI.getAlignment(), DL, &AI, AC, DT); + Copy->getSource(), AI.getAlignment(), DL, &AI, &AC, &DT); if (AI.getAlignment() <= SourceAlign) { DEBUG(dbgs() << "Found alloca equal to global: " << AI << '\n'); DEBUG(dbgs() << " memcpy = " << *Copy << '\n'); @@ -308,6 +308,11 @@ Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) { return visitAllocSite(AI); } +// Are we allowed to form a atomic load or store of this type? +static bool isSupportedAtomicType(Type *Ty) { + return Ty->isIntegerTy() || Ty->isPointerTy() || Ty->isFloatingPointTy(); +} + /// \brief Helper to combine a load to a new type. /// /// This just does the work of combining a load to a new type. It handles @@ -319,6 +324,9 @@ Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) { /// point the \c InstCombiner currently is using. static LoadInst *combineLoadToNewType(InstCombiner &IC, LoadInst &LI, Type *NewTy, const Twine &Suffix = "") { + assert((!LI.isAtomic() || isSupportedAtomicType(NewTy)) && + "can't fold an atomic load to requested type"); + Value *Ptr = LI.getPointerOperand(); unsigned AS = LI.getPointerAddressSpace(); SmallVector<std::pair<unsigned, MDNode *>, 8> MD; @@ -380,8 +388,16 @@ static LoadInst *combineLoadToNewType(InstCombiner &IC, LoadInst &LI, Type *NewT break; case LLVMContext::MD_range: // FIXME: It would be nice to propagate this in some way, but the type - // conversions make it hard. If the new type is a pointer, we could - // translate it to !nonnull metadata. + // conversions make it hard. + + // If it's a pointer now and the range does not contain 0, make it !nonnull. + if (NewTy->isPointerTy()) { + unsigned BitWidth = IC.getDataLayout().getTypeSizeInBits(NewTy); + if (!getConstantRangeFromMetadata(*N).contains(APInt(BitWidth, 0))) { + MDNode *NN = MDNode::get(LI.getContext(), None); + NewLoad->setMetadata(LLVMContext::MD_nonnull, NN); + } + } break; } } @@ -392,6 +408,9 @@ static LoadInst *combineLoadToNewType(InstCombiner &IC, LoadInst &LI, Type *NewT /// /// Returns the newly created store instruction. static StoreInst *combineStoreToNewValue(InstCombiner &IC, StoreInst &SI, Value *V) { + assert((!SI.isAtomic() || isSupportedAtomicType(V->getType())) && + "can't fold an atomic store of requested type"); + Value *Ptr = SI.getPointerOperand(); unsigned AS = SI.getPointerAddressSpace(); SmallVector<std::pair<unsigned, MDNode *>, 8> MD; @@ -466,6 +485,10 @@ static Instruction *combineLoadToOperationType(InstCombiner &IC, LoadInst &LI) { if (LI.use_empty()) return nullptr; + // swifterror values can't be bitcasted. + if (LI.getPointerOperand()->isSwiftError()) + return nullptr; + Type *Ty = LI.getType(); const DataLayout &DL = IC.getDataLayout(); @@ -475,8 +498,9 @@ static Instruction *combineLoadToOperationType(InstCombiner &IC, LoadInst &LI) { // size is a legal integer type. if (!Ty->isIntegerTy() && Ty->isSized() && DL.isLegalInteger(DL.getTypeStoreSizeInBits(Ty)) && - DL.getTypeStoreSizeInBits(Ty) == DL.getTypeSizeInBits(Ty)) { - if (std::all_of(LI.user_begin(), LI.user_end(), [&LI](User *U) { + DL.getTypeStoreSizeInBits(Ty) == DL.getTypeSizeInBits(Ty) && + !DL.isNonIntegralPointerType(Ty)) { + if (all_of(LI.users(), [&LI](User *U) { auto *SI = dyn_cast<StoreInst>(U); return SI && SI->getPointerOperand() != &LI; })) { @@ -501,14 +525,14 @@ static Instruction *combineLoadToOperationType(InstCombiner &IC, LoadInst &LI) { // as long as those are noops (i.e., the source or dest type have the same // bitwidth as the target's pointers). if (LI.hasOneUse()) - if (auto* CI = dyn_cast<CastInst>(LI.user_back())) { - if (CI->isNoopCast(DL)) { - LoadInst *NewLoad = combineLoadToNewType(IC, LI, CI->getDestTy()); - CI->replaceAllUsesWith(NewLoad); - IC.eraseInstFromFunction(*CI); - return &LI; - } - } + if (auto* CI = dyn_cast<CastInst>(LI.user_back())) + if (CI->isNoopCast(DL)) + if (!LI.isAtomic() || isSupportedAtomicType(CI->getDestTy())) { + LoadInst *NewLoad = combineLoadToNewType(IC, LI, CI->getDestTy()); + CI->replaceAllUsesWith(NewLoad); + IC.eraseInstFromFunction(*CI); + return &LI; + } // FIXME: We should also canonicalize loads of vectors when their elements are // cast to other types. @@ -802,7 +826,7 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { // Attempt to improve the alignment. unsigned KnownAlign = getOrEnforceKnownAlignment( - Op, DL.getPrefTypeAlignment(LI.getType()), DL, &LI, AC, DT); + Op, DL.getPrefTypeAlignment(LI.getType()), DL, &LI, &AC, &DT); unsigned LoadAlign = LI.getAlignment(); unsigned EffectiveLoadAlign = LoadAlign != 0 ? LoadAlign : DL.getABITypeAlignment(LI.getType()); @@ -825,11 +849,10 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { // where there are several consecutive memory accesses to the same location, // separated by a few arithmetic operations. BasicBlock::iterator BBI(LI); - AAMDNodes AATags; bool IsLoadCSE = false; if (Value *AvailableVal = FindAvailableLoadedValue(&LI, LI.getParent(), BBI, - DefMaxInstsToScan, AA, &AATags, &IsLoadCSE)) { + DefMaxInstsToScan, AA, &IsLoadCSE)) { if (IsLoadCSE) { LoadInst *NLI = cast<LoadInst>(AvailableVal); unsigned KnownIDs[] = { @@ -1005,19 +1028,26 @@ static bool combineStoreToValueType(InstCombiner &IC, StoreInst &SI) { if (!SI.isUnordered()) return false; + // swifterror values can't be bitcasted. + if (SI.getPointerOperand()->isSwiftError()) + return false; + Value *V = SI.getValueOperand(); // Fold away bit casts of the stored value by storing the original type. if (auto *BC = dyn_cast<BitCastInst>(V)) { V = BC->getOperand(0); - combineStoreToNewValue(IC, SI, V); - return true; + if (!SI.isAtomic() || isSupportedAtomicType(V->getType())) { + combineStoreToNewValue(IC, SI, V); + return true; + } } - if (Value *U = likeBitCastFromVector(IC, V)) { - combineStoreToNewValue(IC, SI, U); - return true; - } + if (Value *U = likeBitCastFromVector(IC, V)) + if (!SI.isAtomic() || isSupportedAtomicType(U->getType())) { + combineStoreToNewValue(IC, SI, U); + return true; + } // FIXME: We should also canonicalize stores of vectors when their elements // are cast to other types. @@ -1169,7 +1199,7 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { // Attempt to improve the alignment. unsigned KnownAlign = getOrEnforceKnownAlignment( - Ptr, DL.getPrefTypeAlignment(Val->getType()), DL, &SI, AC, DT); + Ptr, DL.getPrefTypeAlignment(Val->getType()), DL, &SI, &AC, &DT); unsigned StoreAlign = SI.getAlignment(); unsigned EffectiveStoreAlign = StoreAlign != 0 ? StoreAlign : DL.getABITypeAlignment(Val->getType()); @@ -1293,7 +1323,7 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { bool InstCombiner::SimplifyStoreAtEndOfBlock(StoreInst &SI) { assert(SI.isUnordered() && "this code has not been auditted for volatile or ordered store case"); - + BasicBlock *StoreBB = SI.getParent(); // Check to see if the successor block has exactly two incoming edges. If diff --git a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 788097f33f12..ac64671725f3 100644 --- a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -48,8 +48,8 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC, BinaryOperator *I = dyn_cast<BinaryOperator>(V); if (I && I->isLogicalShift() && isKnownToBeAPowerOfTwo(I->getOperand(0), IC.getDataLayout(), false, 0, - IC.getAssumptionCache(), &CxtI, - IC.getDominatorTree())) { + &IC.getAssumptionCache(), &CxtI, + &IC.getDominatorTree())) { // We know that this is an exact/nuw shift and that the input is a // non-zero context as well. if (Value *V2 = simplifyValueKnownNonZero(I->getOperand(0), IC, CxtI)) { @@ -179,7 +179,7 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); - if (Value *V = SimplifyMulInst(Op0, Op1, DL, TLI, DT, AC)) + if (Value *V = SimplifyMulInst(Op0, Op1, DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); if (Value *V = SimplifyUsingDistributiveLaws(I)) @@ -389,6 +389,80 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { } } + // Check for (mul (sext x), y), see if we can merge this into an + // integer mul followed by a sext. + if (SExtInst *Op0Conv = dyn_cast<SExtInst>(Op0)) { + // (mul (sext x), cst) --> (sext (mul x, cst')) + if (ConstantInt *Op1C = dyn_cast<ConstantInt>(Op1)) { + if (Op0Conv->hasOneUse()) { + Constant *CI = + ConstantExpr::getTrunc(Op1C, Op0Conv->getOperand(0)->getType()); + if (ConstantExpr::getSExt(CI, I.getType()) == Op1C && + WillNotOverflowSignedMul(Op0Conv->getOperand(0), CI, I)) { + // Insert the new, smaller mul. + Value *NewMul = + Builder->CreateNSWMul(Op0Conv->getOperand(0), CI, "mulconv"); + return new SExtInst(NewMul, I.getType()); + } + } + } + + // (mul (sext x), (sext y)) --> (sext (mul int x, y)) + if (SExtInst *Op1Conv = dyn_cast<SExtInst>(Op1)) { + // Only do this if x/y have the same type, if at last one of them has a + // single use (so we don't increase the number of sexts), and if the + // integer mul will not overflow. + if (Op0Conv->getOperand(0)->getType() == + Op1Conv->getOperand(0)->getType() && + (Op0Conv->hasOneUse() || Op1Conv->hasOneUse()) && + WillNotOverflowSignedMul(Op0Conv->getOperand(0), + Op1Conv->getOperand(0), I)) { + // Insert the new integer mul. + Value *NewMul = Builder->CreateNSWMul( + Op0Conv->getOperand(0), Op1Conv->getOperand(0), "mulconv"); + return new SExtInst(NewMul, I.getType()); + } + } + } + + // Check for (mul (zext x), y), see if we can merge this into an + // integer mul followed by a zext. + if (auto *Op0Conv = dyn_cast<ZExtInst>(Op0)) { + // (mul (zext x), cst) --> (zext (mul x, cst')) + if (ConstantInt *Op1C = dyn_cast<ConstantInt>(Op1)) { + if (Op0Conv->hasOneUse()) { + Constant *CI = + ConstantExpr::getTrunc(Op1C, Op0Conv->getOperand(0)->getType()); + if (ConstantExpr::getZExt(CI, I.getType()) == Op1C && + computeOverflowForUnsignedMul(Op0Conv->getOperand(0), CI, &I) == + OverflowResult::NeverOverflows) { + // Insert the new, smaller mul. + Value *NewMul = + Builder->CreateNUWMul(Op0Conv->getOperand(0), CI, "mulconv"); + return new ZExtInst(NewMul, I.getType()); + } + } + } + + // (mul (zext x), (zext y)) --> (zext (mul int x, y)) + if (auto *Op1Conv = dyn_cast<ZExtInst>(Op1)) { + // Only do this if x/y have the same type, if at last one of them has a + // single use (so we don't increase the number of zexts), and if the + // integer mul will not overflow. + if (Op0Conv->getOperand(0)->getType() == + Op1Conv->getOperand(0)->getType() && + (Op0Conv->hasOneUse() || Op1Conv->hasOneUse()) && + computeOverflowForUnsignedMul(Op0Conv->getOperand(0), + Op1Conv->getOperand(0), + &I) == OverflowResult::NeverOverflows) { + // Insert the new integer mul. + Value *NewMul = Builder->CreateNUWMul( + Op0Conv->getOperand(0), Op1Conv->getOperand(0), "mulconv"); + return new ZExtInst(NewMul, I.getType()); + } + } + } + if (!I.hasNoSignedWrap() && WillNotOverflowSignedMul(Op0, Op1, I)) { Changed = true; I.setHasNoSignedWrap(true); @@ -545,7 +619,7 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { std::swap(Op0, Op1); if (Value *V = - SimplifyFMulInst(Op0, Op1, I.getFastMathFlags(), DL, TLI, DT, AC)) + SimplifyFMulInst(Op0, Op1, I.getFastMathFlags(), DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); bool AllowReassociate = I.hasUnsafeAlgebra(); @@ -709,7 +783,6 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { BuilderTy::FastMathFlagGuard Guard(*Builder); Builder->setFastMathFlags(I.getFastMathFlags()); Value *T = Builder->CreateFMul(Opnd1, Opnd1); - Value *R = Builder->CreateFMul(T, Y); R->takeName(&I); return replaceInstUsesWith(I, R); @@ -991,19 +1064,22 @@ static Instruction *foldUDivNegCst(Value *Op0, Value *Op1, } // X udiv (C1 << N), where C1 is "1<<C2" --> X >> (N+C2) +// X udiv (zext (C1 << N)), where C1 is "1<<C2" --> X >> (N+C2) static Instruction *foldUDivShl(Value *Op0, Value *Op1, const BinaryOperator &I, InstCombiner &IC) { - Instruction *ShiftLeft = cast<Instruction>(Op1); - if (isa<ZExtInst>(ShiftLeft)) - ShiftLeft = cast<Instruction>(ShiftLeft->getOperand(0)); + Value *ShiftLeft; + if (!match(Op1, m_ZExt(m_Value(ShiftLeft)))) + ShiftLeft = Op1; - const APInt &CI = - cast<Constant>(ShiftLeft->getOperand(0))->getUniqueInteger(); - Value *N = ShiftLeft->getOperand(1); - if (CI != 1) - N = IC.Builder->CreateAdd(N, ConstantInt::get(N->getType(), CI.logBase2())); - if (ZExtInst *Z = dyn_cast<ZExtInst>(Op1)) - N = IC.Builder->CreateZExt(N, Z->getDestTy()); + const APInt *CI; + Value *N; + if (!match(ShiftLeft, m_Shl(m_APInt(CI), m_Value(N)))) + llvm_unreachable("match should never fail here!"); + if (*CI != 1) + N = IC.Builder->CreateAdd(N, + ConstantInt::get(N->getType(), CI->logBase2())); + if (Op1 != ShiftLeft) + N = IC.Builder->CreateZExt(N, Op1->getType()); BinaryOperator *LShr = BinaryOperator::CreateLShr(Op0, N); if (I.isExact()) LShr->setIsExact(); @@ -1059,7 +1135,7 @@ Instruction *InstCombiner::visitUDiv(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); - if (Value *V = SimplifyUDivInst(Op0, Op1, DL, TLI, DT, AC)) + if (Value *V = SimplifyUDivInst(Op0, Op1, DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); // Handle the integer div common cases @@ -1132,7 +1208,7 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); - if (Value *V = SimplifySDivInst(Op0, Op1, DL, TLI, DT, AC)) + if (Value *V = SimplifySDivInst(Op0, Op1, DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); // Handle the integer div common cases @@ -1195,7 +1271,7 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { return BO; } - if (isKnownToBeAPowerOfTwo(Op1, DL, /*OrZero*/ true, 0, AC, &I, DT)) { + if (isKnownToBeAPowerOfTwo(Op1, DL, /*OrZero*/ true, 0, &AC, &I, &DT)) { // X sdiv (1 << Y) -> X udiv (1 << Y) ( -> X u>> Y) // Safe because the only negative value (1 << Y) can take on is // INT_MIN, and X sdiv INT_MIN == X udiv INT_MIN == 0 if X doesn't have @@ -1247,7 +1323,7 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { return replaceInstUsesWith(I, V); if (Value *V = SimplifyFDivInst(Op0, Op1, I.getFastMathFlags(), - DL, TLI, DT, AC)) + DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); if (isa<Constant>(Op0)) @@ -1421,7 +1497,7 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); - if (Value *V = SimplifyURemInst(Op0, Op1, DL, TLI, DT, AC)) + if (Value *V = SimplifyURemInst(Op0, Op1, DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); if (Instruction *common = commonIRemTransforms(I)) @@ -1434,7 +1510,7 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) { I.getType()); // X urem Y -> X and Y-1, where Y is a power of 2, - if (isKnownToBeAPowerOfTwo(Op1, DL, /*OrZero*/ true, 0, AC, &I, DT)) { + if (isKnownToBeAPowerOfTwo(Op1, DL, /*OrZero*/ true, 0, &AC, &I, &DT)) { Constant *N1 = Constant::getAllOnesValue(I.getType()); Value *Add = Builder->CreateAdd(Op1, N1); return BinaryOperator::CreateAnd(Op0, Add); @@ -1447,6 +1523,14 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) { return replaceInstUsesWith(I, Ext); } + // X urem C -> X < C ? X : X - C, where C >= signbit. + const APInt *DivisorC; + if (match(Op1, m_APInt(DivisorC)) && DivisorC->isNegative()) { + Value *Cmp = Builder->CreateICmpULT(Op0, Op1); + Value *Sub = Builder->CreateSub(Op0, Op1); + return SelectInst::Create(Cmp, Op0, Sub); + } + return nullptr; } @@ -1456,7 +1540,7 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); - if (Value *V = SimplifySRemInst(Op0, Op1, DL, TLI, DT, AC)) + if (Value *V = SimplifySRemInst(Op0, Op1, DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); // Handle the integer rem common cases @@ -1532,7 +1616,7 @@ Instruction *InstCombiner::visitFRem(BinaryOperator &I) { return replaceInstUsesWith(I, V); if (Value *V = SimplifyFRemInst(Op0, Op1, I.getFastMathFlags(), - DL, TLI, DT, AC)) + DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); // Handle cases involving: rem X, (select Cond, Y, Z) diff --git a/lib/Transforms/InstCombine/InstCombinePHI.cpp b/lib/Transforms/InstCombine/InstCombinePHI.cpp index 79a4912332ff..184897f751fe 100644 --- a/lib/Transforms/InstCombine/InstCombinePHI.cpp +++ b/lib/Transforms/InstCombine/InstCombinePHI.cpp @@ -18,11 +18,27 @@ #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Transforms/Utils/Local.h" +#include "llvm/IR/DebugInfo.h" using namespace llvm; using namespace llvm::PatternMatch; #define DEBUG_TYPE "instcombine" +/// The PHI arguments will be folded into a single operation with a PHI node +/// as input. The debug location of the single operation will be the merged +/// locations of the original PHI node arguments. +DebugLoc InstCombiner::PHIArgMergedDebugLoc(PHINode &PN) { + auto *FirstInst = cast<Instruction>(PN.getIncomingValue(0)); + DILocation *Loc = FirstInst->getDebugLoc(); + + for (unsigned i = 1; i != PN.getNumIncomingValues(); ++i) { + auto *I = cast<Instruction>(PN.getIncomingValue(i)); + Loc = DILocation::getMergedLocation(Loc, I->getDebugLoc()); + } + + return Loc; +} + /// If we have something like phi [add (a,b), add(a,c)] and if a/b/c and the /// adds all have a single use, turn this into a phi and a single binop. Instruction *InstCombiner::FoldPHIArgBinOpIntoPHI(PHINode &PN) { @@ -101,7 +117,7 @@ Instruction *InstCombiner::FoldPHIArgBinOpIntoPHI(PHINode &PN) { if (CmpInst *CIOp = dyn_cast<CmpInst>(FirstInst)) { CmpInst *NewCI = CmpInst::Create(CIOp->getOpcode(), CIOp->getPredicate(), LHSVal, RHSVal); - NewCI->setDebugLoc(FirstInst->getDebugLoc()); + NewCI->setDebugLoc(PHIArgMergedDebugLoc(PN)); return NewCI; } @@ -114,7 +130,7 @@ Instruction *InstCombiner::FoldPHIArgBinOpIntoPHI(PHINode &PN) { for (unsigned i = 1, e = PN.getNumIncomingValues(); i != e; ++i) NewBinOp->andIRFlags(PN.getIncomingValue(i)); - NewBinOp->setDebugLoc(FirstInst->getDebugLoc()); + NewBinOp->setDebugLoc(PHIArgMergedDebugLoc(PN)); return NewBinOp; } @@ -223,7 +239,7 @@ Instruction *InstCombiner::FoldPHIArgGEPIntoPHI(PHINode &PN) { GetElementPtrInst::Create(FirstInst->getSourceElementType(), Base, makeArrayRef(FixedOperands).slice(1)); if (AllInBounds) NewGEP->setIsInBounds(); - NewGEP->setDebugLoc(FirstInst->getDebugLoc()); + NewGEP->setDebugLoc(PHIArgMergedDebugLoc(PN)); return NewGEP; } @@ -383,7 +399,7 @@ Instruction *InstCombiner::FoldPHIArgLoadIntoPHI(PHINode &PN) { for (Value *IncValue : PN.incoming_values()) cast<LoadInst>(IncValue)->setVolatile(false); - NewLI->setDebugLoc(FirstLI->getDebugLoc()); + NewLI->setDebugLoc(PHIArgMergedDebugLoc(PN)); return NewLI; } @@ -549,7 +565,7 @@ Instruction *InstCombiner::FoldPHIArgOpIntoPHI(PHINode &PN) { if (CastInst *FirstCI = dyn_cast<CastInst>(FirstInst)) { CastInst *NewCI = CastInst::Create(FirstCI->getOpcode(), PhiVal, PN.getType()); - NewCI->setDebugLoc(FirstInst->getDebugLoc()); + NewCI->setDebugLoc(PHIArgMergedDebugLoc(PN)); return NewCI; } @@ -560,14 +576,14 @@ Instruction *InstCombiner::FoldPHIArgOpIntoPHI(PHINode &PN) { for (unsigned i = 1, e = PN.getNumIncomingValues(); i != e; ++i) BinOp->andIRFlags(PN.getIncomingValue(i)); - BinOp->setDebugLoc(FirstInst->getDebugLoc()); + BinOp->setDebugLoc(PHIArgMergedDebugLoc(PN)); return BinOp; } CmpInst *CIOp = cast<CmpInst>(FirstInst); CmpInst *NewCI = CmpInst::Create(CIOp->getOpcode(), CIOp->getPredicate(), PhiVal, ConstantOp); - NewCI->setDebugLoc(FirstInst->getDebugLoc()); + NewCI->setDebugLoc(PHIArgMergedDebugLoc(PN)); return NewCI; } @@ -835,8 +851,8 @@ Instruction *InstCombiner::SliceUpIllegalIntegerPHI(PHINode &FirstPhi) { // needed piece. if (PHINode *OldInVal = dyn_cast<PHINode>(PN->getIncomingValue(i))) if (PHIsInspected.count(OldInVal)) { - unsigned RefPHIId = std::find(PHIsToSlice.begin(),PHIsToSlice.end(), - OldInVal)-PHIsToSlice.begin(); + unsigned RefPHIId = + find(PHIsToSlice, OldInVal) - PHIsToSlice.begin(); PHIUsers.push_back(PHIUsageRecord(RefPHIId, Offset, cast<Instruction>(Res))); ++UserE; @@ -864,7 +880,7 @@ Instruction *InstCombiner::SliceUpIllegalIntegerPHI(PHINode &FirstPhi) { // PHINode simplification // Instruction *InstCombiner::visitPHINode(PHINode &PN) { - if (Value *V = SimplifyInstruction(&PN, DL, TLI, DT, AC)) + if (Value *V = SimplifyInstruction(&PN, DL, &TLI, &DT, &AC)) return replaceInstUsesWith(PN, V); if (Instruction *Result = FoldPHIArgZextsIntoPHI(PN)) @@ -921,7 +937,7 @@ Instruction *InstCombiner::visitPHINode(PHINode &PN) { for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i) { Instruction *CtxI = PN.getIncomingBlock(i)->getTerminator(); Value *VA = PN.getIncomingValue(i); - if (isKnownNonZero(VA, DL, 0, AC, CtxI, DT)) { + if (isKnownNonZero(VA, DL, 0, &AC, CtxI, &DT)) { if (!NonZeroConst) NonZeroConst = GetAnyNonZeroConstInt(PN); PN.setIncomingValue(i, NonZeroConst); diff --git a/lib/Transforms/InstCombine/InstCombineSelect.cpp b/lib/Transforms/InstCombine/InstCombineSelect.cpp index 8f1ff8ac0e66..36644845352e 100644 --- a/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -15,6 +15,7 @@ #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/MDBuilder.h" #include "llvm/IR/PatternMatch.h" using namespace llvm; using namespace PatternMatch; @@ -78,7 +79,7 @@ static Value *generateMinMaxSelectPattern(InstCombiner::BuilderTy *Builder, /// a bitmask indicating which operands of this instruction are foldable if they /// equal the other incoming value of the select. /// -static unsigned GetSelectFoldableOperands(Instruction *I) { +static unsigned getSelectFoldableOperands(Instruction *I) { switch (I->getOpcode()) { case Instruction::Add: case Instruction::Mul: @@ -98,7 +99,7 @@ static unsigned GetSelectFoldableOperands(Instruction *I) { /// For the same transformation as the previous function, return the identity /// constant that goes into the select. -static Constant *GetSelectFoldableConstant(Instruction *I) { +static Constant *getSelectFoldableConstant(Instruction *I) { switch (I->getOpcode()) { default: llvm_unreachable("This cannot happen!"); case Instruction::Add: @@ -117,7 +118,7 @@ static Constant *GetSelectFoldableConstant(Instruction *I) { } /// We have (select c, TI, FI), and we know that TI and FI have the same opcode. -Instruction *InstCombiner::FoldSelectOpOp(SelectInst &SI, Instruction *TI, +Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI, Instruction *FI) { // If this is a cast from the same type, merge. if (TI->getNumOperands() == 1 && TI->isCast()) { @@ -154,19 +155,19 @@ Instruction *InstCombiner::FoldSelectOpOp(SelectInst &SI, Instruction *TI, } // Fold this by inserting a select from the input values. - Value *NewSI = Builder->CreateSelect(SI.getCondition(), TI->getOperand(0), - FI->getOperand(0), SI.getName()+".v"); + Value *NewSI = + Builder->CreateSelect(SI.getCondition(), TI->getOperand(0), + FI->getOperand(0), SI.getName() + ".v", &SI); return CastInst::Create(Instruction::CastOps(TI->getOpcode()), NewSI, TI->getType()); } - // TODO: This function ends awkwardly in unreachable - fix to be more normal. - // Only handle binary operators with one-use here. As with the cast case // above, it may be possible to relax the one-use constraint, but that needs // be examined carefully since it may not reduce the total number of // instructions. - if (!isa<BinaryOperator>(TI) || !TI->hasOneUse() || !FI->hasOneUse()) + BinaryOperator *BO = dyn_cast<BinaryOperator>(TI); + if (!BO || !TI->hasOneUse() || !FI->hasOneUse()) return nullptr; // Figure out if the operations have any operands in common. @@ -199,16 +200,11 @@ Instruction *InstCombiner::FoldSelectOpOp(SelectInst &SI, Instruction *TI, } // If we reach here, they do have operations in common. - Value *NewSI = Builder->CreateSelect(SI.getCondition(), OtherOpT, - OtherOpF, SI.getName()+".v"); - - if (BinaryOperator *BO = dyn_cast<BinaryOperator>(TI)) { - if (MatchIsOpZero) - return BinaryOperator::Create(BO->getOpcode(), MatchOp, NewSI); - else - return BinaryOperator::Create(BO->getOpcode(), NewSI, MatchOp); - } - llvm_unreachable("Shouldn't get here"); + Value *NewSI = Builder->CreateSelect(SI.getCondition(), OtherOpT, OtherOpF, + SI.getName() + ".v", &SI); + Value *Op0 = MatchIsOpZero ? MatchOp : NewSI; + Value *Op1 = MatchIsOpZero ? NewSI : MatchOp; + return BinaryOperator::Create(BO->getOpcode(), Op0, Op1); } static bool isSelect01(Constant *C1, Constant *C2) { @@ -226,14 +222,14 @@ static bool isSelect01(Constant *C1, Constant *C2) { /// Try to fold the select into one of the operands to allow further /// optimization. -Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal, +Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, Value *FalseVal) { // See the comment above GetSelectFoldableOperands for a description of the // transformation we are doing here. if (Instruction *TVI = dyn_cast<Instruction>(TrueVal)) { if (TVI->hasOneUse() && TVI->getNumOperands() == 2 && !isa<Constant>(FalseVal)) { - if (unsigned SFO = GetSelectFoldableOperands(TVI)) { + if (unsigned SFO = getSelectFoldableOperands(TVI)) { unsigned OpToFold = 0; if ((SFO & 1) && FalseVal == TVI->getOperand(0)) { OpToFold = 1; @@ -242,7 +238,7 @@ Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal, } if (OpToFold) { - Constant *C = GetSelectFoldableConstant(TVI); + Constant *C = getSelectFoldableConstant(TVI); Value *OOp = TVI->getOperand(2-OpToFold); // Avoid creating select between 2 constants unless it's selecting // between 0, 1 and -1. @@ -263,7 +259,7 @@ Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal, if (Instruction *FVI = dyn_cast<Instruction>(FalseVal)) { if (FVI->hasOneUse() && FVI->getNumOperands() == 2 && !isa<Constant>(TrueVal)) { - if (unsigned SFO = GetSelectFoldableOperands(FVI)) { + if (unsigned SFO = getSelectFoldableOperands(FVI)) { unsigned OpToFold = 0; if ((SFO & 1) && TrueVal == FVI->getOperand(0)) { OpToFold = 1; @@ -272,7 +268,7 @@ Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal, } if (OpToFold) { - Constant *C = GetSelectFoldableConstant(FVI); + Constant *C = getSelectFoldableConstant(FVI); Value *OOp = FVI->getOperand(2-OpToFold); // Avoid creating select between 2 constants unless it's selecting // between 0, 1 and -1. @@ -411,102 +407,150 @@ static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal, return nullptr; } -/// Visit a SelectInst that has an ICmpInst as its first operand. -Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI, - ICmpInst *ICI) { - bool Changed = false; - ICmpInst::Predicate Pred = ICI->getPredicate(); - Value *CmpLHS = ICI->getOperand(0); - Value *CmpRHS = ICI->getOperand(1); - Value *TrueVal = SI.getTrueValue(); - Value *FalseVal = SI.getFalseValue(); +/// Return true if we find and adjust an icmp+select pattern where the compare +/// is with a constant that can be incremented or decremented to match the +/// minimum or maximum idiom. +static bool adjustMinMax(SelectInst &Sel, ICmpInst &Cmp) { + ICmpInst::Predicate Pred = Cmp.getPredicate(); + Value *CmpLHS = Cmp.getOperand(0); + Value *CmpRHS = Cmp.getOperand(1); + Value *TrueVal = Sel.getTrueValue(); + Value *FalseVal = Sel.getFalseValue(); - // Check cases where the comparison is with a constant that - // can be adjusted to fit the min/max idiom. We may move or edit ICI - // here, so make sure the select is the only user. - if (ICI->hasOneUse()) - if (ConstantInt *CI = dyn_cast<ConstantInt>(CmpRHS)) { - switch (Pred) { - default: break; - case ICmpInst::ICMP_ULT: - case ICmpInst::ICMP_SLT: - case ICmpInst::ICMP_UGT: - case ICmpInst::ICMP_SGT: { - // These transformations only work for selects over integers. - IntegerType *SelectTy = dyn_cast<IntegerType>(SI.getType()); - if (!SelectTy) - break; + // We may move or edit the compare, so make sure the select is the only user. + const APInt *CmpC; + if (!Cmp.hasOneUse() || !match(CmpRHS, m_APInt(CmpC))) + return false; - Constant *AdjustedRHS; - if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_SGT) - AdjustedRHS = ConstantInt::get(CI->getContext(), CI->getValue() + 1); - else // (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_SLT) - AdjustedRHS = ConstantInt::get(CI->getContext(), CI->getValue() - 1); + // These transforms only work for selects of integers or vector selects of + // integer vectors. + Type *SelTy = Sel.getType(); + auto *SelEltTy = dyn_cast<IntegerType>(SelTy->getScalarType()); + if (!SelEltTy || SelTy->isVectorTy() != Cmp.getType()->isVectorTy()) + return false; - // X > C ? X : C+1 --> X < C+1 ? C+1 : X - // X < C ? X : C-1 --> X > C-1 ? C-1 : X - if ((CmpLHS == TrueVal && AdjustedRHS == FalseVal) || - (CmpLHS == FalseVal && AdjustedRHS == TrueVal)) - ; // Nothing to do here. Values match without any sign/zero extension. + Constant *AdjustedRHS; + if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_SGT) + AdjustedRHS = ConstantInt::get(CmpRHS->getType(), *CmpC + 1); + else if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_SLT) + AdjustedRHS = ConstantInt::get(CmpRHS->getType(), *CmpC - 1); + else + return false; - // Types do not match. Instead of calculating this with mixed types - // promote all to the larger type. This enables scalar evolution to - // analyze this expression. - else if (CmpRHS->getType()->getScalarSizeInBits() - < SelectTy->getBitWidth()) { - Constant *sextRHS = ConstantExpr::getSExt(AdjustedRHS, SelectTy); + // X > C ? X : C+1 --> X < C+1 ? C+1 : X + // X < C ? X : C-1 --> X > C-1 ? C-1 : X + if ((CmpLHS == TrueVal && AdjustedRHS == FalseVal) || + (CmpLHS == FalseVal && AdjustedRHS == TrueVal)) { + ; // Nothing to do here. Values match without any sign/zero extension. + } + // Types do not match. Instead of calculating this with mixed types, promote + // all to the larger type. This enables scalar evolution to analyze this + // expression. + else if (CmpRHS->getType()->getScalarSizeInBits() < SelEltTy->getBitWidth()) { + Constant *SextRHS = ConstantExpr::getSExt(AdjustedRHS, SelTy); - // X = sext x; x >s c ? X : C+1 --> X = sext x; X <s C+1 ? C+1 : X - // X = sext x; x <s c ? X : C-1 --> X = sext x; X >s C-1 ? C-1 : X - // X = sext x; x >u c ? X : C+1 --> X = sext x; X <u C+1 ? C+1 : X - // X = sext x; x <u c ? X : C-1 --> X = sext x; X >u C-1 ? C-1 : X - if (match(TrueVal, m_SExt(m_Specific(CmpLHS))) && - sextRHS == FalseVal) { - CmpLHS = TrueVal; - AdjustedRHS = sextRHS; - } else if (match(FalseVal, m_SExt(m_Specific(CmpLHS))) && - sextRHS == TrueVal) { - CmpLHS = FalseVal; - AdjustedRHS = sextRHS; - } else if (ICI->isUnsigned()) { - Constant *zextRHS = ConstantExpr::getZExt(AdjustedRHS, SelectTy); - // X = zext x; x >u c ? X : C+1 --> X = zext x; X <u C+1 ? C+1 : X - // X = zext x; x <u c ? X : C-1 --> X = zext x; X >u C-1 ? C-1 : X - // zext + signed compare cannot be changed: - // 0xff <s 0x00, but 0x00ff >s 0x0000 - if (match(TrueVal, m_ZExt(m_Specific(CmpLHS))) && - zextRHS == FalseVal) { - CmpLHS = TrueVal; - AdjustedRHS = zextRHS; - } else if (match(FalseVal, m_ZExt(m_Specific(CmpLHS))) && - zextRHS == TrueVal) { - CmpLHS = FalseVal; - AdjustedRHS = zextRHS; - } else - break; - } else - break; - } else - break; + // X = sext x; x >s c ? X : C+1 --> X = sext x; X <s C+1 ? C+1 : X + // X = sext x; x <s c ? X : C-1 --> X = sext x; X >s C-1 ? C-1 : X + // X = sext x; x >u c ? X : C+1 --> X = sext x; X <u C+1 ? C+1 : X + // X = sext x; x <u c ? X : C-1 --> X = sext x; X >u C-1 ? C-1 : X + if (match(TrueVal, m_SExt(m_Specific(CmpLHS))) && SextRHS == FalseVal) { + CmpLHS = TrueVal; + AdjustedRHS = SextRHS; + } else if (match(FalseVal, m_SExt(m_Specific(CmpLHS))) && + SextRHS == TrueVal) { + CmpLHS = FalseVal; + AdjustedRHS = SextRHS; + } else if (Cmp.isUnsigned()) { + Constant *ZextRHS = ConstantExpr::getZExt(AdjustedRHS, SelTy); + // X = zext x; x >u c ? X : C+1 --> X = zext x; X <u C+1 ? C+1 : X + // X = zext x; x <u c ? X : C-1 --> X = zext x; X >u C-1 ? C-1 : X + // zext + signed compare cannot be changed: + // 0xff <s 0x00, but 0x00ff >s 0x0000 + if (match(TrueVal, m_ZExt(m_Specific(CmpLHS))) && ZextRHS == FalseVal) { + CmpLHS = TrueVal; + AdjustedRHS = ZextRHS; + } else if (match(FalseVal, m_ZExt(m_Specific(CmpLHS))) && + ZextRHS == TrueVal) { + CmpLHS = FalseVal; + AdjustedRHS = ZextRHS; + } else { + return false; + } + } else { + return false; + } + } else { + return false; + } - Pred = ICmpInst::getSwappedPredicate(Pred); - CmpRHS = AdjustedRHS; - std::swap(FalseVal, TrueVal); - ICI->setPredicate(Pred); - ICI->setOperand(0, CmpLHS); - ICI->setOperand(1, CmpRHS); - SI.setOperand(1, TrueVal); - SI.setOperand(2, FalseVal); + Pred = ICmpInst::getSwappedPredicate(Pred); + CmpRHS = AdjustedRHS; + std::swap(FalseVal, TrueVal); + Cmp.setPredicate(Pred); + Cmp.setOperand(0, CmpLHS); + Cmp.setOperand(1, CmpRHS); + Sel.setOperand(1, TrueVal); + Sel.setOperand(2, FalseVal); + Sel.swapProfMetadata(); - // Move ICI instruction right before the select instruction. Otherwise - // the sext/zext value may be defined after the ICI instruction uses it. - ICI->moveBefore(&SI); + // Move the compare instruction right before the select instruction. Otherwise + // the sext/zext value may be defined after the compare instruction uses it. + Cmp.moveBefore(&Sel); - Changed = true; - break; - } - } - } + return true; +} + +/// If this is an integer min/max where the select's 'true' operand is a +/// constant, canonicalize that constant to the 'false' operand: +/// select (icmp Pred X, C), C, X --> select (icmp Pred' X, C), X, C +static Instruction * +canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp, + InstCombiner::BuilderTy &Builder) { + // TODO: We should also canonicalize min/max when the select has a different + // constant value than the cmp constant, but we need to fix the backend first. + if (!Cmp.hasOneUse() || !isa<Constant>(Cmp.getOperand(1)) || + !isa<Constant>(Sel.getTrueValue()) || + isa<Constant>(Sel.getFalseValue()) || + Cmp.getOperand(1) != Sel.getTrueValue()) + return nullptr; + + // Canonicalize the compare predicate based on whether we have min or max. + Value *LHS, *RHS; + ICmpInst::Predicate NewPred; + SelectPatternResult SPR = matchSelectPattern(&Sel, LHS, RHS); + switch (SPR.Flavor) { + case SPF_SMIN: NewPred = ICmpInst::ICMP_SLT; break; + case SPF_UMIN: NewPred = ICmpInst::ICMP_ULT; break; + case SPF_SMAX: NewPred = ICmpInst::ICMP_SGT; break; + case SPF_UMAX: NewPred = ICmpInst::ICMP_UGT; break; + default: return nullptr; + } + + // Canonicalize the constant to the right side. + if (isa<Constant>(LHS)) + std::swap(LHS, RHS); + + Value *NewCmp = Builder.CreateICmp(NewPred, LHS, RHS); + SelectInst *NewSel = SelectInst::Create(NewCmp, LHS, RHS, "", nullptr, &Sel); + + // We swapped the select operands, so swap the metadata too. + NewSel->swapProfMetadata(); + return NewSel; +} + +/// Visit a SelectInst that has an ICmpInst as its first operand. +Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, + ICmpInst *ICI) { + if (Instruction *NewSel = canonicalizeMinMaxWithConstant(SI, *ICI, *Builder)) + return NewSel; + + bool Changed = adjustMinMax(SI, *ICI); + + ICmpInst::Predicate Pred = ICI->getPredicate(); + Value *CmpLHS = ICI->getOperand(0); + Value *CmpRHS = ICI->getOperand(1); + Value *TrueVal = SI.getTrueValue(); + Value *FalseVal = SI.getFalseValue(); // Transform (X >s -1) ? C1 : C2 --> ((X >>s 31) & (C2 - C1)) + C1 // and (X <s 0) ? C2 : C1 --> ((X >>s 31) & (C2 - C1)) + C1 @@ -623,7 +667,7 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI, /// /// because Y is not live in BB1/BB2. /// -static bool CanSelectOperandBeMappingIntoPredBlock(const Value *V, +static bool canSelectOperandBeMappingIntoPredBlock(const Value *V, const SelectInst &SI) { // If the value is a non-instruction value like a constant or argument, it // can always be mapped. @@ -651,7 +695,7 @@ static bool CanSelectOperandBeMappingIntoPredBlock(const Value *V, /// We have an SPF (e.g. a min or max) of an SPF of the form: /// SPF2(SPF1(A, B), C) -Instruction *InstCombiner::FoldSPFofSPF(Instruction *Inner, +Instruction *InstCombiner::foldSPFofSPF(Instruction *Inner, SelectPatternFlavor SPF1, Value *A, Value *B, Instruction &Outer, @@ -675,28 +719,24 @@ Instruction *InstCombiner::FoldSPFofSPF(Instruction *Inner, } if (SPF1 == SPF2) { - if (ConstantInt *CB = dyn_cast<ConstantInt>(B)) { - if (ConstantInt *CC = dyn_cast<ConstantInt>(C)) { - const APInt &ACB = CB->getValue(); - const APInt &ACC = CC->getValue(); + const APInt *CB, *CC; + if (match(B, m_APInt(CB)) && match(C, m_APInt(CC))) { + // MIN(MIN(A, 23), 97) -> MIN(A, 23) + // MAX(MAX(A, 97), 23) -> MAX(A, 97) + if ((SPF1 == SPF_UMIN && CB->ule(*CC)) || + (SPF1 == SPF_SMIN && CB->sle(*CC)) || + (SPF1 == SPF_UMAX && CB->uge(*CC)) || + (SPF1 == SPF_SMAX && CB->sge(*CC))) + return replaceInstUsesWith(Outer, Inner); - // MIN(MIN(A, 23), 97) -> MIN(A, 23) - // MAX(MAX(A, 97), 23) -> MAX(A, 97) - if ((SPF1 == SPF_UMIN && ACB.ule(ACC)) || - (SPF1 == SPF_SMIN && ACB.sle(ACC)) || - (SPF1 == SPF_UMAX && ACB.uge(ACC)) || - (SPF1 == SPF_SMAX && ACB.sge(ACC))) - return replaceInstUsesWith(Outer, Inner); - - // MIN(MIN(A, 97), 23) -> MIN(A, 23) - // MAX(MAX(A, 23), 97) -> MAX(A, 97) - if ((SPF1 == SPF_UMIN && ACB.ugt(ACC)) || - (SPF1 == SPF_SMIN && ACB.sgt(ACC)) || - (SPF1 == SPF_UMAX && ACB.ult(ACC)) || - (SPF1 == SPF_SMAX && ACB.slt(ACC))) { - Outer.replaceUsesOfWith(Inner, A); - return &Outer; - } + // MIN(MIN(A, 97), 23) -> MIN(A, 23) + // MAX(MAX(A, 23), 97) -> MAX(A, 97) + if ((SPF1 == SPF_UMIN && CB->ugt(*CC)) || + (SPF1 == SPF_SMIN && CB->sgt(*CC)) || + (SPF1 == SPF_UMAX && CB->ult(*CC)) || + (SPF1 == SPF_SMAX && CB->slt(*CC))) { + Outer.replaceUsesOfWith(Inner, A); + return &Outer; } } } @@ -712,8 +752,9 @@ Instruction *InstCombiner::FoldSPFofSPF(Instruction *Inner, if ((SPF1 == SPF_ABS && SPF2 == SPF_NABS) || (SPF1 == SPF_NABS && SPF2 == SPF_ABS)) { SelectInst *SI = cast<SelectInst>(Inner); - Value *NewSI = Builder->CreateSelect( - SI->getCondition(), SI->getFalseValue(), SI->getTrueValue()); + Value *NewSI = + Builder->CreateSelect(SI->getCondition(), SI->getFalseValue(), + SI->getTrueValue(), SI->getName(), SI); return replaceInstUsesWith(Outer, NewSI); } @@ -895,7 +936,7 @@ static Instruction *foldAddSubSelect(SelectInst &SI, if (AddOp != TI) std::swap(NewTrueOp, NewFalseOp); Value *NewSel = Builder.CreateSelect(CondVal, NewTrueOp, NewFalseOp, - SI.getName() + ".p"); + SI.getName() + ".p", &SI); if (SI.getType()->isFPOrFPVectorTy()) { Instruction *RI = @@ -912,6 +953,147 @@ static Instruction *foldAddSubSelect(SelectInst &SI, return nullptr; } +Instruction *InstCombiner::foldSelectExtConst(SelectInst &Sel) { + Instruction *ExtInst; + if (!match(Sel.getTrueValue(), m_Instruction(ExtInst)) && + !match(Sel.getFalseValue(), m_Instruction(ExtInst))) + return nullptr; + + auto ExtOpcode = ExtInst->getOpcode(); + if (ExtOpcode != Instruction::ZExt && ExtOpcode != Instruction::SExt) + return nullptr; + + // TODO: Handle larger types? That requires adjusting FoldOpIntoSelect too. + Value *X = ExtInst->getOperand(0); + Type *SmallType = X->getType(); + if (!SmallType->getScalarType()->isIntegerTy(1)) + return nullptr; + + Constant *C; + if (!match(Sel.getTrueValue(), m_Constant(C)) && + !match(Sel.getFalseValue(), m_Constant(C))) + return nullptr; + + // If the constant is the same after truncation to the smaller type and + // extension to the original type, we can narrow the select. + Value *Cond = Sel.getCondition(); + Type *SelType = Sel.getType(); + Constant *TruncC = ConstantExpr::getTrunc(C, SmallType); + Constant *ExtC = ConstantExpr::getCast(ExtOpcode, TruncC, SelType); + if (ExtC == C) { + Value *TruncCVal = cast<Value>(TruncC); + if (ExtInst == Sel.getFalseValue()) + std::swap(X, TruncCVal); + + // select Cond, (ext X), C --> ext(select Cond, X, C') + // select Cond, C, (ext X) --> ext(select Cond, C', X) + Value *NewSel = Builder->CreateSelect(Cond, X, TruncCVal, "narrow", &Sel); + return CastInst::Create(Instruction::CastOps(ExtOpcode), NewSel, SelType); + } + + // If one arm of the select is the extend of the condition, replace that arm + // with the extension of the appropriate known bool value. + if (Cond == X) { + if (ExtInst == Sel.getTrueValue()) { + // select X, (sext X), C --> select X, -1, C + // select X, (zext X), C --> select X, 1, C + Constant *One = ConstantInt::getTrue(SmallType); + Constant *AllOnesOrOne = ConstantExpr::getCast(ExtOpcode, One, SelType); + return SelectInst::Create(Cond, AllOnesOrOne, C, "", nullptr, &Sel); + } else { + // select X, C, (sext X) --> select X, C, 0 + // select X, C, (zext X) --> select X, C, 0 + Constant *Zero = ConstantInt::getNullValue(SelType); + return SelectInst::Create(Cond, C, Zero, "", nullptr, &Sel); + } + } + + return nullptr; +} + +/// Try to transform a vector select with a constant condition vector into a +/// shuffle for easier combining with other shuffles and insert/extract. +static Instruction *canonicalizeSelectToShuffle(SelectInst &SI) { + Value *CondVal = SI.getCondition(); + Constant *CondC; + if (!CondVal->getType()->isVectorTy() || !match(CondVal, m_Constant(CondC))) + return nullptr; + + unsigned NumElts = CondVal->getType()->getVectorNumElements(); + SmallVector<Constant *, 16> Mask; + Mask.reserve(NumElts); + Type *Int32Ty = Type::getInt32Ty(CondVal->getContext()); + for (unsigned i = 0; i != NumElts; ++i) { + Constant *Elt = CondC->getAggregateElement(i); + if (!Elt) + return nullptr; + + if (Elt->isOneValue()) { + // If the select condition element is true, choose from the 1st vector. + Mask.push_back(ConstantInt::get(Int32Ty, i)); + } else if (Elt->isNullValue()) { + // If the select condition element is false, choose from the 2nd vector. + Mask.push_back(ConstantInt::get(Int32Ty, i + NumElts)); + } else if (isa<UndefValue>(Elt)) { + // If the select condition element is undef, the shuffle mask is undef. + Mask.push_back(UndefValue::get(Int32Ty)); + } else { + // Bail out on a constant expression. + return nullptr; + } + } + + return new ShuffleVectorInst(SI.getTrueValue(), SI.getFalseValue(), + ConstantVector::get(Mask)); +} + +/// Reuse bitcasted operands between a compare and select: +/// select (cmp (bitcast C), (bitcast D)), (bitcast' C), (bitcast' D) --> +/// bitcast (select (cmp (bitcast C), (bitcast D)), (bitcast C), (bitcast D)) +static Instruction *foldSelectCmpBitcasts(SelectInst &Sel, + InstCombiner::BuilderTy &Builder) { + Value *Cond = Sel.getCondition(); + Value *TVal = Sel.getTrueValue(); + Value *FVal = Sel.getFalseValue(); + + CmpInst::Predicate Pred; + Value *A, *B; + if (!match(Cond, m_Cmp(Pred, m_Value(A), m_Value(B)))) + return nullptr; + + // The select condition is a compare instruction. If the select's true/false + // values are already the same as the compare operands, there's nothing to do. + if (TVal == A || TVal == B || FVal == A || FVal == B) + return nullptr; + + Value *C, *D; + if (!match(A, m_BitCast(m_Value(C))) || !match(B, m_BitCast(m_Value(D)))) + return nullptr; + + // select (cmp (bitcast C), (bitcast D)), (bitcast TSrc), (bitcast FSrc) + Value *TSrc, *FSrc; + if (!match(TVal, m_BitCast(m_Value(TSrc))) || + !match(FVal, m_BitCast(m_Value(FSrc)))) + return nullptr; + + // If the select true/false values are *different bitcasts* of the same source + // operands, make the select operands the same as the compare operands and + // cast the result. This is the canonical select form for min/max. + Value *NewSel; + if (TSrc == C && FSrc == D) { + // select (cmp (bitcast C), (bitcast D)), (bitcast' C), (bitcast' D) --> + // bitcast (select (cmp A, B), A, B) + NewSel = Builder.CreateSelect(Cond, A, B, "", &Sel); + } else if (TSrc == D && FSrc == C) { + // select (cmp (bitcast C), (bitcast D)), (bitcast' D), (bitcast' C) --> + // bitcast (select (cmp A, B), B, A) + NewSel = Builder.CreateSelect(Cond, B, A, "", &Sel); + } else { + return nullptr; + } + return CastInst::CreateBitOrPointerCast(NewSel, Sel.getType()); +} + Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Value *CondVal = SI.getCondition(); Value *TrueVal = SI.getTrueValue(); @@ -919,9 +1101,12 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Type *SelType = SI.getType(); if (Value *V = - SimplifySelectInst(CondVal, TrueVal, FalseVal, DL, TLI, DT, AC)) + SimplifySelectInst(CondVal, TrueVal, FalseVal, DL, &TLI, &DT, &AC)) return replaceInstUsesWith(SI, V); + if (Instruction *I = canonicalizeSelectToShuffle(SI)) + return I; + if (SelType->getScalarType()->isIntegerTy(1) && TrueVal->getType() == CondVal->getType()) { if (match(TrueVal, m_One())) { @@ -1085,7 +1270,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { // See if we are selecting two values based on a comparison of the two values. if (ICmpInst *ICI = dyn_cast<ICmpInst>(CondVal)) - if (Instruction *Result = visitSelectInstWithICmp(SI, ICI)) + if (Instruction *Result = foldSelectInstWithICmp(SI, ICI)) return Result; if (Instruction *Add = foldAddSubSelect(SI, *Builder)) @@ -1095,12 +1280,15 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { auto *TI = dyn_cast<Instruction>(TrueVal); auto *FI = dyn_cast<Instruction>(FalseVal); if (TI && FI && TI->getOpcode() == FI->getOpcode()) - if (Instruction *IV = FoldSelectOpOp(SI, TI, FI)) + if (Instruction *IV = foldSelectOpOp(SI, TI, FI)) return IV; + if (Instruction *I = foldSelectExtConst(SI)) + return I; + // See if we can fold the select into one of our operands. if (SelType->isIntOrIntVectorTy() || SelType->isFPOrFPVectorTy()) { - if (Instruction *FoldI = FoldSelectIntoOp(SI, TrueVal, FalseVal)) + if (Instruction *FoldI = foldSelectIntoOp(SI, TrueVal, FalseVal)) return FoldI; Value *LHS, *RHS, *LHS2, *RHS2; @@ -1124,9 +1312,9 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Cmp = Builder->CreateFCmp(Pred, LHS, RHS); } - Value *NewSI = Builder->CreateCast(CastOp, - Builder->CreateSelect(Cmp, LHS, RHS), - SelType); + Value *NewSI = Builder->CreateCast( + CastOp, Builder->CreateSelect(Cmp, LHS, RHS, SI.getName(), &SI), + SelType); return replaceInstUsesWith(SI, NewSI); } } @@ -1139,39 +1327,35 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { // ABS(ABS(a)) -> ABS(a) // NABS(NABS(a)) -> NABS(a) if (SelectPatternFlavor SPF2 = matchSelectPattern(LHS, LHS2, RHS2).Flavor) - if (Instruction *R = FoldSPFofSPF(cast<Instruction>(LHS),SPF2,LHS2,RHS2, + if (Instruction *R = foldSPFofSPF(cast<Instruction>(LHS),SPF2,LHS2,RHS2, SI, SPF, RHS)) return R; if (SelectPatternFlavor SPF2 = matchSelectPattern(RHS, LHS2, RHS2).Flavor) - if (Instruction *R = FoldSPFofSPF(cast<Instruction>(RHS),SPF2,LHS2,RHS2, + if (Instruction *R = foldSPFofSPF(cast<Instruction>(RHS),SPF2,LHS2,RHS2, SI, SPF, LHS)) return R; } // MAX(~a, ~b) -> ~MIN(a, b) - if (SPF == SPF_SMAX || SPF == SPF_UMAX) { - if (IsFreeToInvert(LHS, LHS->hasNUses(2)) && - IsFreeToInvert(RHS, RHS->hasNUses(2))) { + if ((SPF == SPF_SMAX || SPF == SPF_UMAX) && + IsFreeToInvert(LHS, LHS->hasNUses(2)) && + IsFreeToInvert(RHS, RHS->hasNUses(2))) { + // For this transform to be profitable, we need to eliminate at least two + // 'not' instructions if we're going to add one 'not' instruction. + int NumberOfNots = + (LHS->hasNUses(2) && match(LHS, m_Not(m_Value()))) + + (RHS->hasNUses(2) && match(RHS, m_Not(m_Value()))) + + (SI.hasOneUse() && match(*SI.user_begin(), m_Not(m_Value()))); - // This transform adds a xor operation and that extra cost needs to be - // justified. We look for simplifications that will result from - // applying this rule: - - bool Profitable = - (LHS->hasNUses(2) && match(LHS, m_Not(m_Value()))) || - (RHS->hasNUses(2) && match(RHS, m_Not(m_Value()))) || - (SI.hasOneUse() && match(*SI.user_begin(), m_Not(m_Value()))); - - if (Profitable) { - Value *NewLHS = Builder->CreateNot(LHS); - Value *NewRHS = Builder->CreateNot(RHS); - Value *NewCmp = SPF == SPF_SMAX - ? Builder->CreateICmpSLT(NewLHS, NewRHS) - : Builder->CreateICmpULT(NewLHS, NewRHS); - Value *NewSI = - Builder->CreateNot(Builder->CreateSelect(NewCmp, NewLHS, NewRHS)); - return replaceInstUsesWith(SI, NewSI); - } + if (NumberOfNots >= 2) { + Value *NewLHS = Builder->CreateNot(LHS); + Value *NewRHS = Builder->CreateNot(RHS); + Value *NewCmp = SPF == SPF_SMAX + ? Builder->CreateICmpSLT(NewLHS, NewRHS) + : Builder->CreateICmpULT(NewLHS, NewRHS); + Value *NewSI = + Builder->CreateNot(Builder->CreateSelect(NewCmp, NewLHS, NewRHS)); + return replaceInstUsesWith(SI, NewSI); } } @@ -1182,8 +1366,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { // See if we can fold the select into a phi node if the condition is a select. if (isa<PHINode>(SI.getCondition())) // The true/false values have to be live in the PHI predecessor's blocks. - if (CanSelectOperandBeMappingIntoPredBlock(TrueVal, SI) && - CanSelectOperandBeMappingIntoPredBlock(FalseVal, SI)) + if (canSelectOperandBeMappingIntoPredBlock(TrueVal, SI) && + canSelectOperandBeMappingIntoPredBlock(FalseVal, SI)) if (Instruction *NV = FoldOpIntoPhi(SI)) return NV; @@ -1233,7 +1417,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { return &SI; } - if (VectorType* VecTy = dyn_cast<VectorType>(SelType)) { + if (VectorType *VecTy = dyn_cast<VectorType>(SelType)) { unsigned VWidth = VecTy->getNumElements(); APInt UndefElts(VWidth, 0); APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth)); @@ -1266,5 +1450,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { } } + if (Instruction *BitCastSel = foldSelectCmpBitcasts(SI, *Builder)) + return BitCastSel; + return nullptr; } diff --git a/lib/Transforms/InstCombine/InstCombineShifts.cpp b/lib/Transforms/InstCombine/InstCombineShifts.cpp index 08e16a7ee1af..bc38c4aca348 100644 --- a/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -39,10 +39,19 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { if (Instruction *Res = FoldShiftByConstant(Op0, CUI, I)) return Res; + // (C1 shift (A add C2)) -> (C1 shift C2) shift A) + // iff A and C2 are both positive. + Value *A; + Constant *C; + if (match(Op0, m_Constant()) && match(Op1, m_Add(m_Value(A), m_Constant(C)))) + if (isKnownNonNegative(A, DL) && isKnownNonNegative(C, DL)) + return BinaryOperator::Create( + I.getOpcode(), Builder->CreateBinOp(I.getOpcode(), Op0, C), A); + // X shift (A srem B) -> X shift (A and B-1) iff B is a power of 2. // Because shifts by negative values (which could occur if A were negative) // are undefined. - Value *A; const APInt *B; + const APInt *B; if (Op1->hasOneUse() && match(Op1, m_SRem(m_Value(A), m_Power2(B)))) { // FIXME: Should this get moved into SimplifyDemandedBits by saying we don't // demand the sign bit (and many others) here?? @@ -194,8 +203,10 @@ static Value *GetShiftedValue(Value *V, unsigned NumBits, bool isLeftShift, else V = IC.Builder->CreateLShr(C, NumBits); // If we got a constantexpr back, try to simplify it with TD info. - if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) - V = ConstantFoldConstantExpression(CE, DL, IC.getTargetLibraryInfo()); + if (auto *C = dyn_cast<Constant>(V)) + if (auto *FoldedC = + ConstantFoldConstant(C, DL, &IC.getTargetLibraryInfo())) + V = FoldedC; return V; } @@ -317,7 +328,167 @@ static Value *GetShiftedValue(Value *V, unsigned NumBits, bool isLeftShift, } } +/// Try to fold (X << C1) << C2, where the shifts are some combination of +/// shl/ashr/lshr. +static Instruction * +foldShiftByConstOfShiftByConst(BinaryOperator &I, ConstantInt *COp1, + InstCombiner::BuilderTy *Builder) { + Value *Op0 = I.getOperand(0); + uint32_t TypeBits = Op0->getType()->getScalarSizeInBits(); + + // Find out if this is a shift of a shift by a constant. + BinaryOperator *ShiftOp = dyn_cast<BinaryOperator>(Op0); + if (ShiftOp && !ShiftOp->isShift()) + ShiftOp = nullptr; + + if (ShiftOp && isa<ConstantInt>(ShiftOp->getOperand(1))) { + + // This is a constant shift of a constant shift. Be careful about hiding + // shl instructions behind bit masks. They are used to represent multiplies + // by a constant, and it is important that simple arithmetic expressions + // are still recognizable by scalar evolution. + // + // The transforms applied to shl are very similar to the transforms applied + // to mul by constant. We can be more aggressive about optimizing right + // shifts. + // + // Combinations of right and left shifts will still be optimized in + // DAGCombine where scalar evolution no longer applies. + + ConstantInt *ShiftAmt1C = cast<ConstantInt>(ShiftOp->getOperand(1)); + uint32_t ShiftAmt1 = ShiftAmt1C->getLimitedValue(TypeBits); + uint32_t ShiftAmt2 = COp1->getLimitedValue(TypeBits); + assert(ShiftAmt2 != 0 && "Should have been simplified earlier"); + if (ShiftAmt1 == 0) + return nullptr; // Will be simplified in the future. + Value *X = ShiftOp->getOperand(0); + + IntegerType *Ty = cast<IntegerType>(I.getType()); + + // Check for (X << c1) << c2 and (X >> c1) >> c2 + if (I.getOpcode() == ShiftOp->getOpcode()) { + uint32_t AmtSum = ShiftAmt1 + ShiftAmt2; // Fold into one big shift. + // If this is an oversized composite shift, then unsigned shifts become + // zero (handled in InstSimplify) and ashr saturates. + if (AmtSum >= TypeBits) { + if (I.getOpcode() != Instruction::AShr) + return nullptr; + AmtSum = TypeBits - 1; // Saturate to 31 for i32 ashr. + } + + return BinaryOperator::Create(I.getOpcode(), X, + ConstantInt::get(Ty, AmtSum)); + } + + if (ShiftAmt1 == ShiftAmt2) { + // If we have ((X << C) >>u C), turn this into X & (-1 >>u C). + if (I.getOpcode() == Instruction::LShr && + ShiftOp->getOpcode() == Instruction::Shl) { + APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt1)); + return BinaryOperator::CreateAnd( + X, ConstantInt::get(I.getContext(), Mask)); + } + } else if (ShiftAmt1 < ShiftAmt2) { + uint32_t ShiftDiff = ShiftAmt2 - ShiftAmt1; + + // (X >>?,exact C1) << C2 --> X << (C2-C1) + // The inexact version is deferred to DAGCombine so we don't hide shl + // behind a bit mask. + if (I.getOpcode() == Instruction::Shl && + ShiftOp->getOpcode() != Instruction::Shl && ShiftOp->isExact()) { + assert(ShiftOp->getOpcode() == Instruction::LShr || + ShiftOp->getOpcode() == Instruction::AShr); + ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); + BinaryOperator *NewShl = + BinaryOperator::Create(Instruction::Shl, X, ShiftDiffCst); + NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); + NewShl->setHasNoSignedWrap(I.hasNoSignedWrap()); + return NewShl; + } + + // (X << C1) >>u C2 --> X >>u (C2-C1) & (-1 >> C2) + if (I.getOpcode() == Instruction::LShr && + ShiftOp->getOpcode() == Instruction::Shl) { + ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); + // (X <<nuw C1) >>u C2 --> X >>u (C2-C1) + if (ShiftOp->hasNoUnsignedWrap()) { + BinaryOperator *NewLShr = + BinaryOperator::Create(Instruction::LShr, X, ShiftDiffCst); + NewLShr->setIsExact(I.isExact()); + return NewLShr; + } + Value *Shift = Builder->CreateLShr(X, ShiftDiffCst); + + APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt2)); + return BinaryOperator::CreateAnd( + Shift, ConstantInt::get(I.getContext(), Mask)); + } + + // We can't handle (X << C1) >>s C2, it shifts arbitrary bits in. However, + // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits. + if (I.getOpcode() == Instruction::AShr && + ShiftOp->getOpcode() == Instruction::Shl) { + if (ShiftOp->hasNoSignedWrap()) { + // (X <<nsw C1) >>s C2 --> X >>s (C2-C1) + ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); + BinaryOperator *NewAShr = + BinaryOperator::Create(Instruction::AShr, X, ShiftDiffCst); + NewAShr->setIsExact(I.isExact()); + return NewAShr; + } + } + } else { + assert(ShiftAmt2 < ShiftAmt1); + uint32_t ShiftDiff = ShiftAmt1 - ShiftAmt2; + + // (X >>?exact C1) << C2 --> X >>?exact (C1-C2) + // The inexact version is deferred to DAGCombine so we don't hide shl + // behind a bit mask. + if (I.getOpcode() == Instruction::Shl && + ShiftOp->getOpcode() != Instruction::Shl && ShiftOp->isExact()) { + ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); + BinaryOperator *NewShr = + BinaryOperator::Create(ShiftOp->getOpcode(), X, ShiftDiffCst); + NewShr->setIsExact(true); + return NewShr; + } + + // (X << C1) >>u C2 --> X << (C1-C2) & (-1 >> C2) + if (I.getOpcode() == Instruction::LShr && + ShiftOp->getOpcode() == Instruction::Shl) { + ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); + if (ShiftOp->hasNoUnsignedWrap()) { + // (X <<nuw C1) >>u C2 --> X <<nuw (C1-C2) + BinaryOperator *NewShl = + BinaryOperator::Create(Instruction::Shl, X, ShiftDiffCst); + NewShl->setHasNoUnsignedWrap(true); + return NewShl; + } + Value *Shift = Builder->CreateShl(X, ShiftDiffCst); + APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt2)); + return BinaryOperator::CreateAnd( + Shift, ConstantInt::get(I.getContext(), Mask)); + } + + // We can't handle (X << C1) >>s C2, it shifts arbitrary bits in. However, + // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits. + if (I.getOpcode() == Instruction::AShr && + ShiftOp->getOpcode() == Instruction::Shl) { + if (ShiftOp->hasNoSignedWrap()) { + // (X <<nsw C1) >>s C2 --> X <<nsw (C1-C2) + ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); + BinaryOperator *NewShl = + BinaryOperator::Create(Instruction::Shl, X, ShiftDiffCst); + NewShl->setHasNoSignedWrap(true); + return NewShl; + } + } + } + } + + return nullptr; +} Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, BinaryOperator &I) { @@ -455,9 +626,9 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, V1->getName()+".mask"); return BinaryOperator::Create(Op0BO->getOpcode(), YS, XM); } + LLVM_FALLTHROUGH; } - // FALL THROUGH. case Instruction::Sub: { // Turn ((X >> C) + Y) << C -> (X + (Y << C)) & (~0 << C) if (isLeftShift && Op0BO->getOperand(0)->hasOneUse() && @@ -539,157 +710,9 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, } } - // Find out if this is a shift of a shift by a constant. - BinaryOperator *ShiftOp = dyn_cast<BinaryOperator>(Op0); - if (ShiftOp && !ShiftOp->isShift()) - ShiftOp = nullptr; - - if (ShiftOp && isa<ConstantInt>(ShiftOp->getOperand(1))) { - - // This is a constant shift of a constant shift. Be careful about hiding - // shl instructions behind bit masks. They are used to represent multiplies - // by a constant, and it is important that simple arithmetic expressions - // are still recognizable by scalar evolution. - // - // The transforms applied to shl are very similar to the transforms applied - // to mul by constant. We can be more aggressive about optimizing right - // shifts. - // - // Combinations of right and left shifts will still be optimized in - // DAGCombine where scalar evolution no longer applies. - - ConstantInt *ShiftAmt1C = cast<ConstantInt>(ShiftOp->getOperand(1)); - uint32_t ShiftAmt1 = ShiftAmt1C->getLimitedValue(TypeBits); - uint32_t ShiftAmt2 = COp1->getLimitedValue(TypeBits); - assert(ShiftAmt2 != 0 && "Should have been simplified earlier"); - if (ShiftAmt1 == 0) return nullptr; // Will be simplified in the future. - Value *X = ShiftOp->getOperand(0); - - IntegerType *Ty = cast<IntegerType>(I.getType()); - - // Check for (X << c1) << c2 and (X >> c1) >> c2 - if (I.getOpcode() == ShiftOp->getOpcode()) { - uint32_t AmtSum = ShiftAmt1+ShiftAmt2; // Fold into one big shift. - // If this is oversized composite shift, then unsigned shifts get 0, ashr - // saturates. - if (AmtSum >= TypeBits) { - if (I.getOpcode() != Instruction::AShr) - return replaceInstUsesWith(I, Constant::getNullValue(I.getType())); - AmtSum = TypeBits-1; // Saturate to 31 for i32 ashr. - } - - return BinaryOperator::Create(I.getOpcode(), X, - ConstantInt::get(Ty, AmtSum)); - } - - if (ShiftAmt1 == ShiftAmt2) { - // If we have ((X << C) >>u C), turn this into X & (-1 >>u C). - if (I.getOpcode() == Instruction::LShr && - ShiftOp->getOpcode() == Instruction::Shl) { - APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt1)); - return BinaryOperator::CreateAnd(X, - ConstantInt::get(I.getContext(), Mask)); - } - } else if (ShiftAmt1 < ShiftAmt2) { - uint32_t ShiftDiff = ShiftAmt2-ShiftAmt1; + if (Instruction *Folded = foldShiftByConstOfShiftByConst(I, COp1, Builder)) + return Folded; - // (X >>?,exact C1) << C2 --> X << (C2-C1) - // The inexact version is deferred to DAGCombine so we don't hide shl - // behind a bit mask. - if (I.getOpcode() == Instruction::Shl && - ShiftOp->getOpcode() != Instruction::Shl && - ShiftOp->isExact()) { - assert(ShiftOp->getOpcode() == Instruction::LShr || - ShiftOp->getOpcode() == Instruction::AShr); - ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); - BinaryOperator *NewShl = BinaryOperator::Create(Instruction::Shl, - X, ShiftDiffCst); - NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); - NewShl->setHasNoSignedWrap(I.hasNoSignedWrap()); - return NewShl; - } - - // (X << C1) >>u C2 --> X >>u (C2-C1) & (-1 >> C2) - if (I.getOpcode() == Instruction::LShr && - ShiftOp->getOpcode() == Instruction::Shl) { - ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); - // (X <<nuw C1) >>u C2 --> X >>u (C2-C1) - if (ShiftOp->hasNoUnsignedWrap()) { - BinaryOperator *NewLShr = BinaryOperator::Create(Instruction::LShr, - X, ShiftDiffCst); - NewLShr->setIsExact(I.isExact()); - return NewLShr; - } - Value *Shift = Builder->CreateLShr(X, ShiftDiffCst); - - APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt2)); - return BinaryOperator::CreateAnd(Shift, - ConstantInt::get(I.getContext(),Mask)); - } - - // We can't handle (X << C1) >>s C2, it shifts arbitrary bits in. However, - // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits. - if (I.getOpcode() == Instruction::AShr && - ShiftOp->getOpcode() == Instruction::Shl) { - if (ShiftOp->hasNoSignedWrap()) { - // (X <<nsw C1) >>s C2 --> X >>s (C2-C1) - ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); - BinaryOperator *NewAShr = BinaryOperator::Create(Instruction::AShr, - X, ShiftDiffCst); - NewAShr->setIsExact(I.isExact()); - return NewAShr; - } - } - } else { - assert(ShiftAmt2 < ShiftAmt1); - uint32_t ShiftDiff = ShiftAmt1-ShiftAmt2; - - // (X >>?exact C1) << C2 --> X >>?exact (C1-C2) - // The inexact version is deferred to DAGCombine so we don't hide shl - // behind a bit mask. - if (I.getOpcode() == Instruction::Shl && - ShiftOp->getOpcode() != Instruction::Shl && - ShiftOp->isExact()) { - ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); - BinaryOperator *NewShr = BinaryOperator::Create(ShiftOp->getOpcode(), - X, ShiftDiffCst); - NewShr->setIsExact(true); - return NewShr; - } - - // (X << C1) >>u C2 --> X << (C1-C2) & (-1 >> C2) - if (I.getOpcode() == Instruction::LShr && - ShiftOp->getOpcode() == Instruction::Shl) { - ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); - if (ShiftOp->hasNoUnsignedWrap()) { - // (X <<nuw C1) >>u C2 --> X <<nuw (C1-C2) - BinaryOperator *NewShl = BinaryOperator::Create(Instruction::Shl, - X, ShiftDiffCst); - NewShl->setHasNoUnsignedWrap(true); - return NewShl; - } - Value *Shift = Builder->CreateShl(X, ShiftDiffCst); - - APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt2)); - return BinaryOperator::CreateAnd(Shift, - ConstantInt::get(I.getContext(),Mask)); - } - - // We can't handle (X << C1) >>s C2, it shifts arbitrary bits in. However, - // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits. - if (I.getOpcode() == Instruction::AShr && - ShiftOp->getOpcode() == Instruction::Shl) { - if (ShiftOp->hasNoSignedWrap()) { - // (X <<nsw C1) >>s C2 --> X <<nsw (C1-C2) - ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); - BinaryOperator *NewShl = BinaryOperator::Create(Instruction::Shl, - X, ShiftDiffCst); - NewShl->setHasNoSignedWrap(true); - return NewShl; - } - } - } - } return nullptr; } @@ -699,7 +722,7 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { if (Value *V = SimplifyShlInst(I.getOperand(0), I.getOperand(1), I.hasNoSignedWrap(), - I.hasNoUnsignedWrap(), DL, TLI, DT, AC)) + I.hasNoUnsignedWrap(), DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); if (Instruction *V = commonShiftTransforms(I)) @@ -740,7 +763,7 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) { return replaceInstUsesWith(I, V); if (Value *V = SimplifyLShrInst(I.getOperand(0), I.getOperand(1), I.isExact(), - DL, TLI, DT, AC)) + DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); if (Instruction *R = commonShiftTransforms(I)) @@ -784,7 +807,7 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) { return replaceInstUsesWith(I, V); if (Value *V = SimplifyAShrInst(I.getOperand(0), I.getOperand(1), I.isExact(), - DL, TLI, DT, AC)) + DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); if (Instruction *R = commonShiftTransforms(I)) diff --git a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index f3268d2c3471..8b930bd95dfe 100644 --- a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -981,6 +981,7 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, bool MadeChange = false; APInt UndefElts2(VWidth, 0); + APInt UndefElts3(VWidth, 0); Value *TmpV; switch (I->getOpcode()) { default: break; @@ -1020,8 +1021,8 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, } case Instruction::ShuffleVector: { ShuffleVectorInst *Shuffle = cast<ShuffleVectorInst>(I); - uint64_t LHSVWidth = - cast<VectorType>(Shuffle->getOperand(0)->getType())->getNumElements(); + unsigned LHSVWidth = + Shuffle->getOperand(0)->getType()->getVectorNumElements(); APInt LeftDemanded(LHSVWidth, 0), RightDemanded(LHSVWidth, 0); for (unsigned i = 0; i < VWidth; i++) { if (DemandedElts[i]) { @@ -1037,17 +1038,21 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, } } - APInt UndefElts4(LHSVWidth, 0); + APInt LHSUndefElts(LHSVWidth, 0); TmpV = SimplifyDemandedVectorElts(I->getOperand(0), LeftDemanded, - UndefElts4, Depth + 1); + LHSUndefElts, Depth + 1); if (TmpV) { I->setOperand(0, TmpV); MadeChange = true; } - APInt UndefElts3(LHSVWidth, 0); + APInt RHSUndefElts(LHSVWidth, 0); TmpV = SimplifyDemandedVectorElts(I->getOperand(1), RightDemanded, - UndefElts3, Depth + 1); + RHSUndefElts, Depth + 1); if (TmpV) { I->setOperand(1, TmpV); MadeChange = true; } bool NewUndefElts = false; + unsigned LHSIdx = -1u, LHSValIdx = -1u; + unsigned RHSIdx = -1u, RHSValIdx = -1u; + bool LHSUniform = true; + bool RHSUniform = true; for (unsigned i = 0; i < VWidth; i++) { unsigned MaskVal = Shuffle->getMaskValue(i); if (MaskVal == -1u) { @@ -1056,18 +1061,59 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, NewUndefElts = true; UndefElts.setBit(i); } else if (MaskVal < LHSVWidth) { - if (UndefElts4[MaskVal]) { + if (LHSUndefElts[MaskVal]) { NewUndefElts = true; UndefElts.setBit(i); + } else { + LHSIdx = LHSIdx == -1u ? i : LHSVWidth; + LHSValIdx = LHSValIdx == -1u ? MaskVal : LHSVWidth; + LHSUniform = LHSUniform && (MaskVal == i); } } else { - if (UndefElts3[MaskVal - LHSVWidth]) { + if (RHSUndefElts[MaskVal - LHSVWidth]) { NewUndefElts = true; UndefElts.setBit(i); + } else { + RHSIdx = RHSIdx == -1u ? i : LHSVWidth; + RHSValIdx = RHSValIdx == -1u ? MaskVal - LHSVWidth : LHSVWidth; + RHSUniform = RHSUniform && (MaskVal - LHSVWidth == i); } } } + // Try to transform shuffle with constant vector and single element from + // this constant vector to single insertelement instruction. + // shufflevector V, C, <v1, v2, .., ci, .., vm> -> + // insertelement V, C[ci], ci-n + if (LHSVWidth == Shuffle->getType()->getNumElements()) { + Value *Op = nullptr; + Constant *Value = nullptr; + unsigned Idx = -1u; + + // Find constant vector with the single element in shuffle (LHS or RHS). + if (LHSIdx < LHSVWidth && RHSUniform) { + if (auto *CV = dyn_cast<ConstantVector>(Shuffle->getOperand(0))) { + Op = Shuffle->getOperand(1); + Value = CV->getOperand(LHSValIdx); + Idx = LHSIdx; + } + } + if (RHSIdx < LHSVWidth && LHSUniform) { + if (auto *CV = dyn_cast<ConstantVector>(Shuffle->getOperand(1))) { + Op = Shuffle->getOperand(0); + Value = CV->getOperand(RHSValIdx); + Idx = RHSIdx; + } + } + // Found constant vector with single element - convert to insertelement. + if (Op && Value) { + Instruction *New = InsertElementInst::Create( + Op, Value, ConstantInt::get(Type::getInt32Ty(I->getContext()), Idx), + Shuffle->getName()); + InsertNewInstWith(New, *Shuffle); + return New; + } + } if (NewUndefElts) { // Add additional discovered undefs. SmallVector<Constant*, 16> Elts; @@ -1209,114 +1255,223 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, switch (II->getIntrinsicID()) { default: break; + case Intrinsic::x86_xop_vfrcz_ss: + case Intrinsic::x86_xop_vfrcz_sd: + // The instructions for these intrinsics are speced to zero upper bits not + // pass them through like other scalar intrinsics. So we shouldn't just + // use Arg0 if DemandedElts[0] is clear like we do for other intrinsics. + // Instead we should return a zero vector. + if (!DemandedElts[0]) { + Worklist.Add(II); + return ConstantAggregateZero::get(II->getType()); + } + + // Only the lower element is used. + DemandedElts = 1; + TmpV = SimplifyDemandedVectorElts(II->getArgOperand(0), DemandedElts, + UndefElts, Depth + 1); + if (TmpV) { II->setArgOperand(0, TmpV); MadeChange = true; } + + // Only the lower element is undefined. The high elements are zero. + UndefElts = UndefElts[0]; + break; + // Unary scalar-as-vector operations that work column-wise. case Intrinsic::x86_sse_rcp_ss: case Intrinsic::x86_sse_rsqrt_ss: case Intrinsic::x86_sse_sqrt_ss: case Intrinsic::x86_sse2_sqrt_sd: - case Intrinsic::x86_xop_vfrcz_ss: - case Intrinsic::x86_xop_vfrcz_sd: TmpV = SimplifyDemandedVectorElts(II->getArgOperand(0), DemandedElts, UndefElts, Depth + 1); if (TmpV) { II->setArgOperand(0, TmpV); MadeChange = true; } // If lowest element of a scalar op isn't used then use Arg0. - if (DemandedElts.getLoBits(1) != 1) + if (!DemandedElts[0]) { + Worklist.Add(II); return II->getArgOperand(0); + } // TODO: If only low elt lower SQRT to FSQRT (with rounding/exceptions // checks). break; - // Binary scalar-as-vector operations that work column-wise. A dest element - // is a function of the corresponding input elements from the two inputs. - case Intrinsic::x86_sse_add_ss: - case Intrinsic::x86_sse_sub_ss: - case Intrinsic::x86_sse_mul_ss: - case Intrinsic::x86_sse_div_ss: + // Binary scalar-as-vector operations that work column-wise. The high + // elements come from operand 0. The low element is a function of both + // operands. case Intrinsic::x86_sse_min_ss: case Intrinsic::x86_sse_max_ss: case Intrinsic::x86_sse_cmp_ss: - case Intrinsic::x86_sse2_add_sd: - case Intrinsic::x86_sse2_sub_sd: - case Intrinsic::x86_sse2_mul_sd: - case Intrinsic::x86_sse2_div_sd: case Intrinsic::x86_sse2_min_sd: case Intrinsic::x86_sse2_max_sd: - case Intrinsic::x86_sse2_cmp_sd: - case Intrinsic::x86_sse41_round_ss: - case Intrinsic::x86_sse41_round_sd: + case Intrinsic::x86_sse2_cmp_sd: { TmpV = SimplifyDemandedVectorElts(II->getArgOperand(0), DemandedElts, UndefElts, Depth + 1); if (TmpV) { II->setArgOperand(0, TmpV); MadeChange = true; } + + // If lowest element of a scalar op isn't used then use Arg0. + if (!DemandedElts[0]) { + Worklist.Add(II); + return II->getArgOperand(0); + } + + // Only lower element is used for operand 1. + DemandedElts = 1; TmpV = SimplifyDemandedVectorElts(II->getArgOperand(1), DemandedElts, UndefElts2, Depth + 1); if (TmpV) { II->setArgOperand(1, TmpV); MadeChange = true; } - // If only the low elt is demanded and this is a scalarizable intrinsic, - // scalarize it now. - if (DemandedElts == 1) { - switch (II->getIntrinsicID()) { - default: break; - case Intrinsic::x86_sse_add_ss: - case Intrinsic::x86_sse_sub_ss: - case Intrinsic::x86_sse_mul_ss: - case Intrinsic::x86_sse_div_ss: - case Intrinsic::x86_sse2_add_sd: - case Intrinsic::x86_sse2_sub_sd: - case Intrinsic::x86_sse2_mul_sd: - case Intrinsic::x86_sse2_div_sd: - // TODO: Lower MIN/MAX/etc. - Value *LHS = II->getArgOperand(0); - Value *RHS = II->getArgOperand(1); - // Extract the element as scalars. - LHS = InsertNewInstWith(ExtractElementInst::Create(LHS, - ConstantInt::get(Type::getInt32Ty(I->getContext()), 0U)), *II); - RHS = InsertNewInstWith(ExtractElementInst::Create(RHS, - ConstantInt::get(Type::getInt32Ty(I->getContext()), 0U)), *II); + // Lower element is undefined if both lower elements are undefined. + // Consider things like undef&0. The result is known zero, not undef. + if (!UndefElts2[0]) + UndefElts.clearBit(0); - switch (II->getIntrinsicID()) { - default: llvm_unreachable("Case stmts out of sync!"); - case Intrinsic::x86_sse_add_ss: - case Intrinsic::x86_sse2_add_sd: - TmpV = InsertNewInstWith(BinaryOperator::CreateFAdd(LHS, RHS, - II->getName()), *II); - break; - case Intrinsic::x86_sse_sub_ss: - case Intrinsic::x86_sse2_sub_sd: - TmpV = InsertNewInstWith(BinaryOperator::CreateFSub(LHS, RHS, - II->getName()), *II); - break; - case Intrinsic::x86_sse_mul_ss: - case Intrinsic::x86_sse2_mul_sd: - TmpV = InsertNewInstWith(BinaryOperator::CreateFMul(LHS, RHS, - II->getName()), *II); - break; - case Intrinsic::x86_sse_div_ss: - case Intrinsic::x86_sse2_div_sd: - TmpV = InsertNewInstWith(BinaryOperator::CreateFDiv(LHS, RHS, - II->getName()), *II); - break; - } + break; + } - Instruction *New = - InsertElementInst::Create( - UndefValue::get(II->getType()), TmpV, - ConstantInt::get(Type::getInt32Ty(I->getContext()), 0U, false), - II->getName()); - InsertNewInstWith(New, *II); - return New; - } + // Binary scalar-as-vector operations that work column-wise. The high + // elements come from operand 0 and the low element comes from operand 1. + case Intrinsic::x86_sse41_round_ss: + case Intrinsic::x86_sse41_round_sd: { + // Don't use the low element of operand 0. + APInt DemandedElts2 = DemandedElts; + DemandedElts2.clearBit(0); + TmpV = SimplifyDemandedVectorElts(II->getArgOperand(0), DemandedElts2, + UndefElts, Depth + 1); + if (TmpV) { II->setArgOperand(0, TmpV); MadeChange = true; } + + // If lowest element of a scalar op isn't used then use Arg0. + if (!DemandedElts[0]) { + Worklist.Add(II); + return II->getArgOperand(0); } + // Only lower element is used for operand 1. + DemandedElts = 1; + TmpV = SimplifyDemandedVectorElts(II->getArgOperand(1), DemandedElts, + UndefElts2, Depth + 1); + if (TmpV) { II->setArgOperand(1, TmpV); MadeChange = true; } + + // Take the high undef elements from operand 0 and take the lower element + // from operand 1. + UndefElts.clearBit(0); + UndefElts |= UndefElts2[0]; + break; + } + + // Three input scalar-as-vector operations that work column-wise. The high + // elements come from operand 0 and the low element is a function of all + // three inputs. + case Intrinsic::x86_avx512_mask_add_ss_round: + case Intrinsic::x86_avx512_mask_div_ss_round: + case Intrinsic::x86_avx512_mask_mul_ss_round: + case Intrinsic::x86_avx512_mask_sub_ss_round: + case Intrinsic::x86_avx512_mask_max_ss_round: + case Intrinsic::x86_avx512_mask_min_ss_round: + case Intrinsic::x86_avx512_mask_add_sd_round: + case Intrinsic::x86_avx512_mask_div_sd_round: + case Intrinsic::x86_avx512_mask_mul_sd_round: + case Intrinsic::x86_avx512_mask_sub_sd_round: + case Intrinsic::x86_avx512_mask_max_sd_round: + case Intrinsic::x86_avx512_mask_min_sd_round: + case Intrinsic::x86_fma_vfmadd_ss: + case Intrinsic::x86_fma_vfmsub_ss: + case Intrinsic::x86_fma_vfnmadd_ss: + case Intrinsic::x86_fma_vfnmsub_ss: + case Intrinsic::x86_fma_vfmadd_sd: + case Intrinsic::x86_fma_vfmsub_sd: + case Intrinsic::x86_fma_vfnmadd_sd: + case Intrinsic::x86_fma_vfnmsub_sd: + case Intrinsic::x86_avx512_mask_vfmadd_ss: + case Intrinsic::x86_avx512_mask_vfmadd_sd: + case Intrinsic::x86_avx512_maskz_vfmadd_ss: + case Intrinsic::x86_avx512_maskz_vfmadd_sd: + TmpV = SimplifyDemandedVectorElts(II->getArgOperand(0), DemandedElts, + UndefElts, Depth + 1); + if (TmpV) { II->setArgOperand(0, TmpV); MadeChange = true; } + // If lowest element of a scalar op isn't used then use Arg0. - if (DemandedElts.getLoBits(1) != 1) + if (!DemandedElts[0]) { + Worklist.Add(II); return II->getArgOperand(0); + } + + // Only lower element is used for operand 1 and 2. + DemandedElts = 1; + TmpV = SimplifyDemandedVectorElts(II->getArgOperand(1), DemandedElts, + UndefElts2, Depth + 1); + if (TmpV) { II->setArgOperand(1, TmpV); MadeChange = true; } + TmpV = SimplifyDemandedVectorElts(II->getArgOperand(2), DemandedElts, + UndefElts3, Depth + 1); + if (TmpV) { II->setArgOperand(2, TmpV); MadeChange = true; } + + // Lower element is undefined if all three lower elements are undefined. + // Consider things like undef&0. The result is known zero, not undef. + if (!UndefElts2[0] || !UndefElts3[0]) + UndefElts.clearBit(0); + + break; + + case Intrinsic::x86_avx512_mask3_vfmadd_ss: + case Intrinsic::x86_avx512_mask3_vfmadd_sd: + case Intrinsic::x86_avx512_mask3_vfmsub_ss: + case Intrinsic::x86_avx512_mask3_vfmsub_sd: + case Intrinsic::x86_avx512_mask3_vfnmsub_ss: + case Intrinsic::x86_avx512_mask3_vfnmsub_sd: + // These intrinsics get the passthru bits from operand 2. + TmpV = SimplifyDemandedVectorElts(II->getArgOperand(2), DemandedElts, + UndefElts, Depth + 1); + if (TmpV) { II->setArgOperand(2, TmpV); MadeChange = true; } + + // If lowest element of a scalar op isn't used then use Arg2. + if (!DemandedElts[0]) { + Worklist.Add(II); + return II->getArgOperand(2); + } + + // Only lower element is used for operand 0 and 1. + DemandedElts = 1; + TmpV = SimplifyDemandedVectorElts(II->getArgOperand(0), DemandedElts, + UndefElts2, Depth + 1); + if (TmpV) { II->setArgOperand(0, TmpV); MadeChange = true; } + TmpV = SimplifyDemandedVectorElts(II->getArgOperand(1), DemandedElts, + UndefElts3, Depth + 1); + if (TmpV) { II->setArgOperand(1, TmpV); MadeChange = true; } + + // Lower element is undefined if all three lower elements are undefined. + // Consider things like undef&0. The result is known zero, not undef. + if (!UndefElts2[0] || !UndefElts3[0]) + UndefElts.clearBit(0); - // Output elements are undefined if both are undefined. Consider things - // like undef&0. The result is known zero, not undef. - UndefElts &= UndefElts2; break; + case Intrinsic::x86_sse2_pmulu_dq: + case Intrinsic::x86_sse41_pmuldq: + case Intrinsic::x86_avx2_pmul_dq: + case Intrinsic::x86_avx2_pmulu_dq: + case Intrinsic::x86_avx512_pmul_dq_512: + case Intrinsic::x86_avx512_pmulu_dq_512: { + Value *Op0 = II->getArgOperand(0); + Value *Op1 = II->getArgOperand(1); + unsigned InnerVWidth = Op0->getType()->getVectorNumElements(); + assert((VWidth * 2) == InnerVWidth && "Unexpected input size"); + + APInt InnerDemandedElts(InnerVWidth, 0); + for (unsigned i = 0; i != VWidth; ++i) + if (DemandedElts[i]) + InnerDemandedElts.setBit(i * 2); + + UndefElts2 = APInt(InnerVWidth, 0); + TmpV = SimplifyDemandedVectorElts(Op0, InnerDemandedElts, UndefElts2, + Depth + 1); + if (TmpV) { II->setArgOperand(0, TmpV); MadeChange = true; } + + UndefElts3 = APInt(InnerVWidth, 0); + TmpV = SimplifyDemandedVectorElts(Op1, InnerDemandedElts, UndefElts3, + Depth + 1); + if (TmpV) { II->setArgOperand(1, TmpV); MadeChange = true; } + + break; + } + // SSE4A instructions leave the upper 64-bits of the 128-bit result // in an undefined state. case Intrinsic::x86_sse4a_extrq: diff --git a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index a76138756148..b2477f6c8633 100644 --- a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -145,7 +145,7 @@ Instruction *InstCombiner::scalarizePHI(ExtractElementInst &EI, PHINode *PN) { Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) { if (Value *V = SimplifyExtractElementInst( - EI.getVectorOperand(), EI.getIndexOperand(), DL, TLI, DT, AC)) + EI.getVectorOperand(), EI.getIndexOperand(), DL, &TLI, &DT, &AC)) return replaceInstUsesWith(EI, V); // If vector val is constant with all elements the same, replace EI with @@ -413,6 +413,14 @@ static void replaceExtractElements(InsertElementInst *InsElt, if (InsertionBlock != InsElt->getParent()) return; + // TODO: This restriction matches the check in visitInsertElementInst() and + // prevents an infinite loop caused by not turning the extract/insert pair + // into a shuffle. We really should not need either check, but we're lacking + // folds for shufflevectors because we're afraid to generate shuffle masks + // that the backend can't handle. + if (InsElt->hasOneUse() && isa<InsertElementInst>(InsElt->user_back())) + return; + auto *WideVec = new ShuffleVectorInst(ExtVecOp, UndefValue::get(ExtVecType), ConstantVector::get(ExtendMask)); @@ -452,7 +460,7 @@ static ShuffleOps collectShuffleElements(Value *V, Value *PermittedRHS, InstCombiner &IC) { assert(V->getType()->isVectorTy() && "Invalid shuffle!"); - unsigned NumElts = cast<VectorType>(V->getType())->getNumElements(); + unsigned NumElts = V->getType()->getVectorNumElements(); if (isa<UndefValue>(V)) { Mask.assign(NumElts, UndefValue::get(Type::getInt32Ty(V->getContext()))); @@ -566,6 +574,176 @@ Instruction *InstCombiner::visitInsertValueInst(InsertValueInst &I) { return nullptr; } +static bool isShuffleEquivalentToSelect(ShuffleVectorInst &Shuf) { + int MaskSize = Shuf.getMask()->getType()->getVectorNumElements(); + int VecSize = Shuf.getOperand(0)->getType()->getVectorNumElements(); + + // A vector select does not change the size of the operands. + if (MaskSize != VecSize) + return false; + + // Each mask element must be undefined or choose a vector element from one of + // the source operands without crossing vector lanes. + for (int i = 0; i != MaskSize; ++i) { + int Elt = Shuf.getMaskValue(i); + if (Elt != -1 && Elt != i && Elt != i + VecSize) + return false; + } + + return true; +} + +// Turn a chain of inserts that splats a value into a canonical insert + shuffle +// splat. That is: +// insertelt(insertelt(insertelt(insertelt X, %k, 0), %k, 1), %k, 2) ... -> +// shufflevector(insertelt(X, %k, 0), undef, zero) +static Instruction *foldInsSequenceIntoBroadcast(InsertElementInst &InsElt) { + // We are interested in the last insert in a chain. So, if this insert + // has a single user, and that user is an insert, bail. + if (InsElt.hasOneUse() && isa<InsertElementInst>(InsElt.user_back())) + return nullptr; + + VectorType *VT = cast<VectorType>(InsElt.getType()); + int NumElements = VT->getNumElements(); + + // Do not try to do this for a one-element vector, since that's a nop, + // and will cause an inf-loop. + if (NumElements == 1) + return nullptr; + + Value *SplatVal = InsElt.getOperand(1); + InsertElementInst *CurrIE = &InsElt; + SmallVector<bool, 16> ElementPresent(NumElements, false); + + // Walk the chain backwards, keeping track of which indices we inserted into, + // until we hit something that isn't an insert of the splatted value. + while (CurrIE) { + ConstantInt *Idx = dyn_cast<ConstantInt>(CurrIE->getOperand(2)); + if (!Idx || CurrIE->getOperand(1) != SplatVal) + return nullptr; + + // Check none of the intermediate steps have any additional uses. + if ((CurrIE != &InsElt) && !CurrIE->hasOneUse()) + return nullptr; + + ElementPresent[Idx->getZExtValue()] = true; + CurrIE = dyn_cast<InsertElementInst>(CurrIE->getOperand(0)); + } + + // Make sure we've seen an insert into every element. + if (llvm::any_of(ElementPresent, [](bool Present) { return !Present; })) + return nullptr; + + // All right, create the insert + shuffle. + Instruction *InsertFirst = InsertElementInst::Create( + UndefValue::get(VT), SplatVal, + ConstantInt::get(Type::getInt32Ty(InsElt.getContext()), 0), "", &InsElt); + + Constant *ZeroMask = ConstantAggregateZero::get( + VectorType::get(Type::getInt32Ty(InsElt.getContext()), NumElements)); + + return new ShuffleVectorInst(InsertFirst, UndefValue::get(VT), ZeroMask); +} + +/// insertelt (shufflevector X, CVec, Mask|insertelt X, C1, CIndex1), C, CIndex +/// --> shufflevector X, CVec', Mask' +static Instruction *foldConstantInsEltIntoShuffle(InsertElementInst &InsElt) { + auto *Inst = dyn_cast<Instruction>(InsElt.getOperand(0)); + // Bail out if the parent has more than one use. In that case, we'd be + // replacing the insertelt with a shuffle, and that's not a clear win. + if (!Inst || !Inst->hasOneUse()) + return nullptr; + if (auto *Shuf = dyn_cast<ShuffleVectorInst>(InsElt.getOperand(0))) { + // The shuffle must have a constant vector operand. The insertelt must have + // a constant scalar being inserted at a constant position in the vector. + Constant *ShufConstVec, *InsEltScalar; + uint64_t InsEltIndex; + if (!match(Shuf->getOperand(1), m_Constant(ShufConstVec)) || + !match(InsElt.getOperand(1), m_Constant(InsEltScalar)) || + !match(InsElt.getOperand(2), m_ConstantInt(InsEltIndex))) + return nullptr; + + // Adding an element to an arbitrary shuffle could be expensive, but a + // shuffle that selects elements from vectors without crossing lanes is + // assumed cheap. + // If we're just adding a constant into that shuffle, it will still be + // cheap. + if (!isShuffleEquivalentToSelect(*Shuf)) + return nullptr; + + // From the above 'select' check, we know that the mask has the same number + // of elements as the vector input operands. We also know that each constant + // input element is used in its lane and can not be used more than once by + // the shuffle. Therefore, replace the constant in the shuffle's constant + // vector with the insertelt constant. Replace the constant in the shuffle's + // mask vector with the insertelt index plus the length of the vector + // (because the constant vector operand of a shuffle is always the 2nd + // operand). + Constant *Mask = Shuf->getMask(); + unsigned NumElts = Mask->getType()->getVectorNumElements(); + SmallVector<Constant *, 16> NewShufElts(NumElts); + SmallVector<Constant *, 16> NewMaskElts(NumElts); + for (unsigned I = 0; I != NumElts; ++I) { + if (I == InsEltIndex) { + NewShufElts[I] = InsEltScalar; + Type *Int32Ty = Type::getInt32Ty(Shuf->getContext()); + NewMaskElts[I] = ConstantInt::get(Int32Ty, InsEltIndex + NumElts); + } else { + // Copy over the existing values. + NewShufElts[I] = ShufConstVec->getAggregateElement(I); + NewMaskElts[I] = Mask->getAggregateElement(I); + } + } + + // Create new operands for a shuffle that includes the constant of the + // original insertelt. The old shuffle will be dead now. + return new ShuffleVectorInst(Shuf->getOperand(0), + ConstantVector::get(NewShufElts), + ConstantVector::get(NewMaskElts)); + } else if (auto *IEI = dyn_cast<InsertElementInst>(Inst)) { + // Transform sequences of insertelements ops with constant data/indexes into + // a single shuffle op. + unsigned NumElts = InsElt.getType()->getNumElements(); + + uint64_t InsertIdx[2]; + Constant *Val[2]; + if (!match(InsElt.getOperand(2), m_ConstantInt(InsertIdx[0])) || + !match(InsElt.getOperand(1), m_Constant(Val[0])) || + !match(IEI->getOperand(2), m_ConstantInt(InsertIdx[1])) || + !match(IEI->getOperand(1), m_Constant(Val[1]))) + return nullptr; + SmallVector<Constant *, 16> Values(NumElts); + SmallVector<Constant *, 16> Mask(NumElts); + auto ValI = std::begin(Val); + // Generate new constant vector and mask. + // We have 2 values/masks from the insertelements instructions. Insert them + // into new value/mask vectors. + for (uint64_t I : InsertIdx) { + if (!Values[I]) { + assert(!Mask[I]); + Values[I] = *ValI; + Mask[I] = ConstantInt::get(Type::getInt32Ty(InsElt.getContext()), + NumElts + I); + } + ++ValI; + } + // Remaining values are filled with 'undef' values. + for (unsigned I = 0; I < NumElts; ++I) { + if (!Values[I]) { + assert(!Mask[I]); + Values[I] = UndefValue::get(InsElt.getType()->getElementType()); + Mask[I] = ConstantInt::get(Type::getInt32Ty(InsElt.getContext()), I); + } + } + // Create new operands for a shuffle that includes the constant of the + // original insertelt. + return new ShuffleVectorInst(IEI->getOperand(0), + ConstantVector::get(Values), + ConstantVector::get(Mask)); + } + return nullptr; +} + Instruction *InstCombiner::visitInsertElementInst(InsertElementInst &IE) { Value *VecOp = IE.getOperand(0); Value *ScalarOp = IE.getOperand(1); @@ -616,7 +794,7 @@ Instruction *InstCombiner::visitInsertElementInst(InsertElementInst &IE) { } } - unsigned VWidth = cast<VectorType>(VecOp->getType())->getNumElements(); + unsigned VWidth = VecOp->getType()->getVectorNumElements(); APInt UndefElts(VWidth, 0); APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth)); if (Value *V = SimplifyDemandedVectorElts(&IE, AllOnesEltMask, UndefElts)) { @@ -625,6 +803,14 @@ Instruction *InstCombiner::visitInsertElementInst(InsertElementInst &IE) { return &IE; } + if (Instruction *Shuf = foldConstantInsEltIntoShuffle(IE)) + return Shuf; + + // Turn a sequence of inserts that broadcasts a scalar into a single + // insert + shufflevector. + if (Instruction *Broadcast = foldInsSequenceIntoBroadcast(IE)) + return Broadcast; + return nullptr; } @@ -903,8 +1089,7 @@ static void recognizeIdentityMask(const SmallVectorImpl<int> &Mask, // +--+--+--+--+ static bool isShuffleExtractingFromLHS(ShuffleVectorInst &SVI, SmallVector<int, 16> &Mask) { - unsigned LHSElems = - cast<VectorType>(SVI.getOperand(0)->getType())->getNumElements(); + unsigned LHSElems = SVI.getOperand(0)->getType()->getVectorNumElements(); unsigned MaskElems = Mask.size(); unsigned BegIdx = Mask.front(); unsigned EndIdx = Mask.back(); @@ -928,7 +1113,7 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { if (isa<UndefValue>(SVI.getOperand(2))) return replaceInstUsesWith(SVI, UndefValue::get(SVI.getType())); - unsigned VWidth = cast<VectorType>(SVI.getType())->getNumElements(); + unsigned VWidth = SVI.getType()->getVectorNumElements(); APInt UndefElts(VWidth, 0); APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth)); @@ -940,7 +1125,7 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { MadeChange = true; } - unsigned LHSWidth = cast<VectorType>(LHS->getType())->getNumElements(); + unsigned LHSWidth = LHS->getType()->getVectorNumElements(); // Canonicalize shuffle(x ,x,mask) -> shuffle(x, undef,mask') // Canonicalize shuffle(undef,x,mask) -> shuffle(x, undef,mask'). @@ -1143,11 +1328,11 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { if (LHSShuffle) { LHSOp0 = LHSShuffle->getOperand(0); LHSOp1 = LHSShuffle->getOperand(1); - LHSOp0Width = cast<VectorType>(LHSOp0->getType())->getNumElements(); + LHSOp0Width = LHSOp0->getType()->getVectorNumElements(); } if (RHSShuffle) { RHSOp0 = RHSShuffle->getOperand(0); - RHSOp0Width = cast<VectorType>(RHSOp0->getType())->getNumElements(); + RHSOp0Width = RHSOp0->getType()->getVectorNumElements(); } Value* newLHS = LHS; Value* newRHS = RHS; diff --git a/lib/Transforms/InstCombine/InstructionCombining.cpp b/lib/Transforms/InstCombine/InstructionCombining.cpp index 377ccb9c37f7..9a52874c4c21 100644 --- a/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -177,11 +177,10 @@ static bool simplifyAssocCastAssoc(BinaryOperator *BinOp1) { return false; // TODO: Enhance logic for other BinOps and remove this check. - auto AssocOpcode = BinOp1->getOpcode(); - if (AssocOpcode != Instruction::Xor && AssocOpcode != Instruction::And && - AssocOpcode != Instruction::Or) + if (!BinOp1->isBitwiseLogicOp()) return false; + auto AssocOpcode = BinOp1->getOpcode(); auto *BinOp2 = dyn_cast<BinaryOperator>(Cast->getOperand(0)); if (!BinOp2 || !BinOp2->hasOneUse() || BinOp2->getOpcode() != AssocOpcode) return false; @@ -684,14 +683,14 @@ Value *InstCombiner::SimplifyUsingDistributiveLaws(BinaryOperator &I) { if (SI0->getCondition() == SI1->getCondition()) { Value *SI = nullptr; if (Value *V = SimplifyBinOp(TopLevelOpcode, SI0->getFalseValue(), - SI1->getFalseValue(), DL, TLI, DT, AC)) + SI1->getFalseValue(), DL, &TLI, &DT, &AC)) SI = Builder->CreateSelect(SI0->getCondition(), Builder->CreateBinOp(TopLevelOpcode, SI0->getTrueValue(), SI1->getTrueValue()), V); if (Value *V = SimplifyBinOp(TopLevelOpcode, SI0->getTrueValue(), - SI1->getTrueValue(), DL, TLI, DT, AC)) + SI1->getTrueValue(), DL, &TLI, &DT, &AC)) SI = Builder->CreateSelect( SI0->getCondition(), V, Builder->CreateBinOp(TopLevelOpcode, SI0->getFalseValue(), @@ -741,17 +740,18 @@ Value *InstCombiner::dyn_castFNegVal(Value *V, bool IgnoreZeroSign) const { return nullptr; } -static Value *FoldOperationIntoSelectOperand(Instruction &I, Value *SO, +static Value *foldOperationIntoSelectOperand(Instruction &I, Value *SO, InstCombiner *IC) { - if (CastInst *CI = dyn_cast<CastInst>(&I)) { - return IC->Builder->CreateCast(CI->getOpcode(), SO, I.getType()); - } + if (auto *Cast = dyn_cast<CastInst>(&I)) + return IC->Builder->CreateCast(Cast->getOpcode(), SO, I.getType()); + + assert(I.isBinaryOp() && "Unexpected opcode for select folding"); // Figure out if the constant is the left or the right argument. bool ConstIsRHS = isa<Constant>(I.getOperand(1)); Constant *ConstOperand = cast<Constant>(I.getOperand(ConstIsRHS)); - if (Constant *SOC = dyn_cast<Constant>(SO)) { + if (auto *SOC = dyn_cast<Constant>(SO)) { if (ConstIsRHS) return ConstantExpr::get(I.getOpcode(), SOC, ConstOperand); return ConstantExpr::get(I.getOpcode(), ConstOperand, SOC); @@ -761,21 +761,13 @@ static Value *FoldOperationIntoSelectOperand(Instruction &I, Value *SO, if (!ConstIsRHS) std::swap(Op0, Op1); - if (BinaryOperator *BO = dyn_cast<BinaryOperator>(&I)) { - Value *RI = IC->Builder->CreateBinOp(BO->getOpcode(), Op0, Op1, - SO->getName()+".op"); - Instruction *FPInst = dyn_cast<Instruction>(RI); - if (FPInst && isa<FPMathOperator>(FPInst)) - FPInst->copyFastMathFlags(BO); - return RI; - } - if (ICmpInst *CI = dyn_cast<ICmpInst>(&I)) - return IC->Builder->CreateICmp(CI->getPredicate(), Op0, Op1, - SO->getName()+".cmp"); - if (FCmpInst *CI = dyn_cast<FCmpInst>(&I)) - return IC->Builder->CreateICmp(CI->getPredicate(), Op0, Op1, - SO->getName()+".cmp"); - llvm_unreachable("Unknown binary instruction type!"); + auto *BO = cast<BinaryOperator>(&I); + Value *RI = IC->Builder->CreateBinOp(BO->getOpcode(), Op0, Op1, + SO->getName() + ".op"); + auto *FPInst = dyn_cast<Instruction>(RI); + if (FPInst && isa<FPMathOperator>(FPInst)) + FPInst->copyFastMathFlags(BO); + return RI; } /// Given an instruction with a select as one operand and a constant as the @@ -783,51 +775,53 @@ static Value *FoldOperationIntoSelectOperand(Instruction &I, Value *SO, /// This also works for Cast instructions, which obviously do not have a second /// operand. Instruction *InstCombiner::FoldOpIntoSelect(Instruction &Op, SelectInst *SI) { - // Don't modify shared select instructions - if (!SI->hasOneUse()) return nullptr; - Value *TV = SI->getOperand(1); - Value *FV = SI->getOperand(2); + // Don't modify shared select instructions. + if (!SI->hasOneUse()) + return nullptr; - if (isa<Constant>(TV) || isa<Constant>(FV)) { - // Bool selects with constant operands can be folded to logical ops. - if (SI->getType()->isIntegerTy(1)) return nullptr; + Value *TV = SI->getTrueValue(); + Value *FV = SI->getFalseValue(); + if (!(isa<Constant>(TV) || isa<Constant>(FV))) + return nullptr; - // If it's a bitcast involving vectors, make sure it has the same number of - // elements on both sides. - if (BitCastInst *BC = dyn_cast<BitCastInst>(&Op)) { - VectorType *DestTy = dyn_cast<VectorType>(BC->getDestTy()); - VectorType *SrcTy = dyn_cast<VectorType>(BC->getSrcTy()); + // Bool selects with constant operands can be folded to logical ops. + if (SI->getType()->getScalarType()->isIntegerTy(1)) + return nullptr; - // Verify that either both or neither are vectors. - if ((SrcTy == nullptr) != (DestTy == nullptr)) return nullptr; - // If vectors, verify that they have the same number of elements. - if (SrcTy && SrcTy->getNumElements() != DestTy->getNumElements()) - return nullptr; - } + // If it's a bitcast involving vectors, make sure it has the same number of + // elements on both sides. + if (auto *BC = dyn_cast<BitCastInst>(&Op)) { + VectorType *DestTy = dyn_cast<VectorType>(BC->getDestTy()); + VectorType *SrcTy = dyn_cast<VectorType>(BC->getSrcTy()); - // Test if a CmpInst instruction is used exclusively by a select as - // part of a minimum or maximum operation. If so, refrain from doing - // any other folding. This helps out other analyses which understand - // non-obfuscated minimum and maximum idioms, such as ScalarEvolution - // and CodeGen. And in this case, at least one of the comparison - // operands has at least one user besides the compare (the select), - // which would often largely negate the benefit of folding anyway. - if (auto *CI = dyn_cast<CmpInst>(SI->getCondition())) { - if (CI->hasOneUse()) { - Value *Op0 = CI->getOperand(0), *Op1 = CI->getOperand(1); - if ((SI->getOperand(1) == Op0 && SI->getOperand(2) == Op1) || - (SI->getOperand(2) == Op0 && SI->getOperand(1) == Op1)) - return nullptr; - } - } + // Verify that either both or neither are vectors. + if ((SrcTy == nullptr) != (DestTy == nullptr)) + return nullptr; - Value *SelectTrueVal = FoldOperationIntoSelectOperand(Op, TV, this); - Value *SelectFalseVal = FoldOperationIntoSelectOperand(Op, FV, this); + // If vectors, verify that they have the same number of elements. + if (SrcTy && SrcTy->getNumElements() != DestTy->getNumElements()) + return nullptr; + } - return SelectInst::Create(SI->getCondition(), - SelectTrueVal, SelectFalseVal); + // Test if a CmpInst instruction is used exclusively by a select as + // part of a minimum or maximum operation. If so, refrain from doing + // any other folding. This helps out other analyses which understand + // non-obfuscated minimum and maximum idioms, such as ScalarEvolution + // and CodeGen. And in this case, at least one of the comparison + // operands has at least one user besides the compare (the select), + // which would often largely negate the benefit of folding anyway. + if (auto *CI = dyn_cast<CmpInst>(SI->getCondition())) { + if (CI->hasOneUse()) { + Value *Op0 = CI->getOperand(0), *Op1 = CI->getOperand(1); + if ((SI->getOperand(1) == Op0 && SI->getOperand(2) == Op1) || + (SI->getOperand(2) == Op0 && SI->getOperand(1) == Op1)) + return nullptr; + } } - return nullptr; + + Value *NewTV = foldOperationIntoSelectOperand(Op, TV, this); + Value *NewFV = foldOperationIntoSelectOperand(Op, FV, this); + return SelectInst::Create(SI->getCondition(), NewTV, NewFV, "", nullptr, SI); } /// Given a binary operator, cast instruction, or select which has a PHI node as @@ -877,7 +871,7 @@ Instruction *InstCombiner::FoldOpIntoPhi(Instruction &I) { // If the incoming non-constant value is in I's block, we will remove one // instruction, but insert another equivalent one, leading to infinite // instcombine. - if (isPotentiallyReachable(I.getParent(), NonConstBB, DT, LI)) + if (isPotentiallyReachable(I.getParent(), NonConstBB, &DT, LI)) return nullptr; } @@ -1379,7 +1373,8 @@ Value *InstCombiner::SimplifyVectorOp(BinaryOperator &Inst) { Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { SmallVector<Value*, 8> Ops(GEP.op_begin(), GEP.op_end()); - if (Value *V = SimplifyGEPInst(GEP.getSourceElementType(), Ops, DL, TLI, DT, AC)) + if (Value *V = + SimplifyGEPInst(GEP.getSourceElementType(), Ops, DL, &TLI, &DT, &AC)) return replaceInstUsesWith(GEP, V); Value *PtrOp = GEP.getOperand(0); @@ -1394,7 +1389,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { for (User::op_iterator I = GEP.op_begin() + 1, E = GEP.op_end(); I != E; ++I, ++GTI) { // Skip indices into struct types. - if (isa<StructType>(*GTI)) + if (GTI.isStruct()) continue; // Index type should have the same width as IntPtr @@ -1551,7 +1546,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { bool EndsWithSequential = false; for (gep_type_iterator I = gep_type_begin(*Src), E = gep_type_end(*Src); I != E; ++I) - EndsWithSequential = !(*I)->isStructTy(); + EndsWithSequential = I.isSequential(); // Can we combine the two pointer arithmetics offsets? if (EndsWithSequential) { @@ -1860,7 +1855,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (!Offset) { // If the bitcast is of an allocation, and the allocation will be // converted to match the type of the cast, don't touch this. - if (isa<AllocaInst>(Operand) || isAllocationFn(Operand, TLI)) { + if (isa<AllocaInst>(Operand) || isAllocationFn(Operand, &TLI)) { // See if the bitcast simplifies, if so, don't nuke this GEP yet. if (Instruction *I = visitBitCast(*BCI)) { if (I != BCI) { @@ -1898,6 +1893,25 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { } } + if (!GEP.isInBounds()) { + unsigned PtrWidth = + DL.getPointerSizeInBits(PtrOp->getType()->getPointerAddressSpace()); + APInt BasePtrOffset(PtrWidth, 0); + Value *UnderlyingPtrOp = + PtrOp->stripAndAccumulateInBoundsConstantOffsets(DL, + BasePtrOffset); + if (auto *AI = dyn_cast<AllocaInst>(UnderlyingPtrOp)) { + if (GEP.accumulateConstantOffset(DL, BasePtrOffset) && + BasePtrOffset.isNonNegative()) { + APInt AllocSize(PtrWidth, DL.getTypeAllocSize(AI->getAllocatedType())); + if (BasePtrOffset.ule(AllocSize)) { + return GetElementPtrInst::CreateInBounds( + PtrOp, makeArrayRef(Ops).slice(1), GEP.getName()); + } + } + } + } + return nullptr; } @@ -1963,8 +1977,8 @@ isAllocSiteRemovable(Instruction *AI, SmallVectorImpl<WeakVH> &Users, MemIntrinsic *MI = cast<MemIntrinsic>(II); if (MI->isVolatile() || MI->getRawDest() != PI) return false; + LLVM_FALLTHROUGH; } - // fall through case Intrinsic::dbg_declare: case Intrinsic::dbg_value: case Intrinsic::invariant_start: @@ -2002,7 +2016,7 @@ Instruction *InstCombiner::visitAllocSite(Instruction &MI) { // to null and free calls, delete the calls and replace the comparisons with // true or false as appropriate. SmallVector<WeakVH, 64> Users; - if (isAllocSiteRemovable(&MI, Users, TLI)) { + if (isAllocSiteRemovable(&MI, Users, &TLI)) { for (unsigned i = 0, e = Users.size(); i != e; ++i) { // Lowering all @llvm.objectsize calls first because they may // use a bitcast/GEP of the alloca we are removing. @@ -2013,12 +2027,9 @@ Instruction *InstCombiner::visitAllocSite(Instruction &MI) { if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { if (II->getIntrinsicID() == Intrinsic::objectsize) { - uint64_t Size; - if (!getObjectSize(II->getArgOperand(0), Size, DL, TLI)) { - ConstantInt *CI = cast<ConstantInt>(II->getArgOperand(1)); - Size = CI->isZero() ? -1ULL : 0; - } - replaceInstUsesWith(*I, ConstantInt::get(I->getType(), Size)); + ConstantInt *Result = lowerObjectSizeCall(II, DL, &TLI, + /*MustSucceed=*/true); + replaceInstUsesWith(*I, Result); eraseInstFromFunction(*I); Users[i] = nullptr; // Skip examining in the next loop. } @@ -2218,6 +2229,20 @@ Instruction *InstCombiner::visitBranchInst(BranchInst &BI) { Instruction *InstCombiner::visitSwitchInst(SwitchInst &SI) { Value *Cond = SI.getCondition(); + Value *Op0; + ConstantInt *AddRHS; + if (match(Cond, m_Add(m_Value(Op0), m_ConstantInt(AddRHS)))) { + // Change 'switch (X+4) case 1:' into 'switch (X) case -3'. + for (SwitchInst::CaseIt CaseIter : SI.cases()) { + Constant *NewCase = ConstantExpr::getSub(CaseIter.getCaseValue(), AddRHS); + assert(isa<ConstantInt>(NewCase) && + "Result of expression should be constant"); + CaseIter.setValue(cast<ConstantInt>(NewCase)); + } + SI.setCondition(Op0); + return &SI; + } + unsigned BitWidth = cast<IntegerType>(Cond->getType())->getBitWidth(); APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); computeKnownBits(Cond, KnownZero, KnownOne, 0, &SI); @@ -2238,43 +2263,20 @@ Instruction *InstCombiner::visitSwitchInst(SwitchInst &SI) { // Shrink the condition operand if the new type is smaller than the old type. // This may produce a non-standard type for the switch, but that's ok because // the backend should extend back to a legal type for the target. - bool TruncCond = false; if (NewWidth > 0 && NewWidth < BitWidth) { - TruncCond = true; IntegerType *Ty = IntegerType::get(SI.getContext(), NewWidth); Builder->SetInsertPoint(&SI); Value *NewCond = Builder->CreateTrunc(Cond, Ty, "trunc"); SI.setCondition(NewCond); - for (auto &C : SI.cases()) - static_cast<SwitchInst::CaseIt *>(&C)->setValue(ConstantInt::get( - SI.getContext(), C.getCaseValue()->getValue().trunc(NewWidth))); - } - - ConstantInt *AddRHS = nullptr; - if (match(Cond, m_Add(m_Value(), m_ConstantInt(AddRHS)))) { - Instruction *I = cast<Instruction>(Cond); - // Change 'switch (X+4) case 1:' into 'switch (X) case -3'. - for (SwitchInst::CaseIt i = SI.case_begin(), e = SI.case_end(); i != e; - ++i) { - ConstantInt *CaseVal = i.getCaseValue(); - Constant *LHS = CaseVal; - if (TruncCond) { - LHS = LeadingKnownZeros - ? ConstantExpr::getZExt(CaseVal, Cond->getType()) - : ConstantExpr::getSExt(CaseVal, Cond->getType()); - } - Constant *NewCaseVal = ConstantExpr::getSub(LHS, AddRHS); - assert(isa<ConstantInt>(NewCaseVal) && - "Result of expression should be constant"); - i.setValue(cast<ConstantInt>(NewCaseVal)); + for (SwitchInst::CaseIt CaseIter : SI.cases()) { + APInt TruncatedCase = CaseIter.getCaseValue()->getValue().trunc(NewWidth); + CaseIter.setValue(ConstantInt::get(SI.getContext(), TruncatedCase)); } - SI.setCondition(I->getOperand(0)); - Worklist.Add(I); return &SI; } - return TruncCond ? &SI : nullptr; + return nullptr; } Instruction *InstCombiner::visitExtractValueInst(ExtractValueInst &EV) { @@ -2284,7 +2286,7 @@ Instruction *InstCombiner::visitExtractValueInst(ExtractValueInst &EV) { return replaceInstUsesWith(EV, Agg); if (Value *V = - SimplifyExtractValueInst(Agg, EV.getIndices(), DL, TLI, DT, AC)) + SimplifyExtractValueInst(Agg, EV.getIndices(), DL, &TLI, &DT, &AC)) return replaceInstUsesWith(EV, V); if (InsertValueInst *IV = dyn_cast<InsertValueInst>(Agg)) { @@ -2560,7 +2562,7 @@ Instruction *InstCombiner::visitLandingPadInst(LandingPadInst &LI) { // remove it from the filter. An unexpected type handler may be // set up for a call site which throws an exception of the same // type caught. In order for the exception thrown by the unexpected - // handler to propogate correctly, the filter must be correctly + // handler to propagate correctly, the filter must be correctly // described for the call site. // // Example: @@ -2813,7 +2815,7 @@ bool InstCombiner::run() { if (I == nullptr) continue; // skip null values. // Check to see if we can DCE the instruction. - if (isInstructionTriviallyDead(I, TLI)) { + if (isInstructionTriviallyDead(I, &TLI)) { DEBUG(dbgs() << "IC: DCE: " << *I << '\n'); eraseInstFromFunction(*I); ++NumDeadInst; @@ -2824,13 +2826,13 @@ bool InstCombiner::run() { // Instruction isn't dead, see if we can constant propagate it. if (!I->use_empty() && (I->getNumOperands() == 0 || isa<Constant>(I->getOperand(0)))) { - if (Constant *C = ConstantFoldInstruction(I, DL, TLI)) { + if (Constant *C = ConstantFoldInstruction(I, DL, &TLI)) { DEBUG(dbgs() << "IC: ConstFold to: " << *C << " from: " << *I << '\n'); // Add operands to the worklist. replaceInstUsesWith(*I, C); ++NumConstProp; - if (isInstructionTriviallyDead(I, TLI)) + if (isInstructionTriviallyDead(I, &TLI)) eraseInstFromFunction(*I); MadeIRChange = true; continue; @@ -2839,20 +2841,21 @@ bool InstCombiner::run() { // In general, it is possible for computeKnownBits to determine all bits in // a value even when the operands are not all constants. - if (ExpensiveCombines && !I->use_empty() && I->getType()->isIntegerTy()) { - unsigned BitWidth = I->getType()->getScalarSizeInBits(); + Type *Ty = I->getType(); + if (ExpensiveCombines && !I->use_empty() && Ty->isIntOrIntVectorTy()) { + unsigned BitWidth = Ty->getScalarSizeInBits(); APInt KnownZero(BitWidth, 0); APInt KnownOne(BitWidth, 0); computeKnownBits(I, KnownZero, KnownOne, /*Depth*/0, I); if ((KnownZero | KnownOne).isAllOnesValue()) { - Constant *C = ConstantInt::get(I->getContext(), KnownOne); + Constant *C = ConstantInt::get(Ty, KnownOne); DEBUG(dbgs() << "IC: ConstFold (all bits known) to: " << *C << " from: " << *I << '\n'); // Add operands to the worklist. replaceInstUsesWith(*I, C); ++NumConstProp; - if (isInstructionTriviallyDead(I, TLI)) + if (isInstructionTriviallyDead(I, &TLI)) eraseInstFromFunction(*I); MadeIRChange = true; continue; @@ -2883,7 +2886,7 @@ bool InstCombiner::run() { // If the user is one of our immediate successors, and if that successor // only has us as a predecessors (we'd have to split the critical edge // otherwise), we can keep going. - if (UserIsSuccessor && UserParent->getSinglePredecessor()) { + if (UserIsSuccessor && UserParent->getUniquePredecessor()) { // Okay, the CFG is simple enough, try to sink this instruction. if (TryToSinkInstruction(I, UserParent)) { DEBUG(dbgs() << "IC: Sink: " << *I << '\n'); @@ -2941,14 +2944,12 @@ bool InstCombiner::run() { eraseInstFromFunction(*I); } else { -#ifndef NDEBUG DEBUG(dbgs() << "IC: Mod = " << OrigI << '\n' << " New = " << *I << '\n'); -#endif // If the instruction was modified, it's possible that it is now dead. // if so, remove it. - if (isInstructionTriviallyDead(I, TLI)) { + if (isInstructionTriviallyDead(I, &TLI)) { eraseInstFromFunction(*I); } else { Worklist.Add(I); @@ -2981,7 +2982,7 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL, Worklist.push_back(BB); SmallVector<Instruction*, 128> InstrsForInstCombineWorklist; - DenseMap<ConstantExpr*, Constant*> FoldedConstants; + DenseMap<Constant *, Constant *> FoldedConstants; do { BB = Worklist.pop_back_val(); @@ -3017,17 +3018,17 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL, // See if we can constant fold its operands. for (User::op_iterator i = Inst->op_begin(), e = Inst->op_end(); i != e; ++i) { - ConstantExpr *CE = dyn_cast<ConstantExpr>(i); - if (CE == nullptr) + if (!isa<ConstantVector>(i) && !isa<ConstantExpr>(i)) continue; - Constant *&FoldRes = FoldedConstants[CE]; + auto *C = cast<Constant>(i); + Constant *&FoldRes = FoldedConstants[C]; if (!FoldRes) - FoldRes = ConstantFoldConstantExpression(CE, DL, TLI); + FoldRes = ConstantFoldConstant(C, DL, TLI); if (!FoldRes) - FoldRes = CE; + FoldRes = C; - if (FoldRes != CE) { + if (FoldRes != C) { *i = FoldRes; MadeIRChange = true; } @@ -3120,8 +3121,15 @@ combineInstructionsOverFunction(Function &F, InstCombineWorklist &Worklist, /// Builder - This is an IRBuilder that automatically inserts new /// instructions into the worklist when they are created. - IRBuilder<TargetFolder, InstCombineIRInserter> Builder( - F.getContext(), TargetFolder(DL), InstCombineIRInserter(Worklist, &AC)); + IRBuilder<TargetFolder, IRBuilderCallbackInserter> Builder( + F.getContext(), TargetFolder(DL), + IRBuilderCallbackInserter([&Worklist, &AC](Instruction *I) { + Worklist.Add(I); + + using namespace llvm::PatternMatch; + if (match(I, m_Intrinsic<Intrinsic::assume>())) + AC.registerAssumption(cast<CallInst>(I)); + })); // Lower dbg.declare intrinsics otherwise their value may be clobbered // by instcombiner. @@ -3137,7 +3145,7 @@ combineInstructionsOverFunction(Function &F, InstCombineWorklist &Worklist, bool Changed = prepareICWorklistFromFunction(F, DL, &TLI, Worklist); InstCombiner IC(Worklist, &Builder, F.optForMinSize(), ExpensiveCombines, - AA, &AC, &TLI, &DT, DL, LI); + AA, AC, TLI, DT, DL, LI); Changed |= IC.run(); if (!Changed) @@ -3148,7 +3156,7 @@ combineInstructionsOverFunction(Function &F, InstCombineWorklist &Worklist, } PreservedAnalyses InstCombinePass::run(Function &F, - AnalysisManager<Function> &AM) { + FunctionAnalysisManager &AM) { auto &AC = AM.getResult<AssumptionAnalysis>(F); auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); |
