Diffstat (limited to 'lib/Transforms/InstCombine')
-rw-r--r-- lib/Transforms/InstCombine/InstCombineAddSub.cpp | 93
-rw-r--r-- lib/Transforms/InstCombine/InstCombineAndOrXor.cpp | 791
-rw-r--r-- lib/Transforms/InstCombine/InstCombineCalls.cpp | 1515
-rw-r--r-- lib/Transforms/InstCombine/InstCombineCasts.cpp | 46
-rw-r--r-- lib/Transforms/InstCombine/InstCombineCompares.cpp | 983
-rw-r--r-- lib/Transforms/InstCombine/InstCombineInternal.h | 32
-rw-r--r-- lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp | 283
-rw-r--r-- lib/Transforms/InstCombine/InstCombineMulDivRem.cpp | 165
-rw-r--r-- lib/Transforms/InstCombine/InstCombinePHI.cpp | 101
-rw-r--r-- lib/Transforms/InstCombine/InstCombineSelect.cpp | 379
-rw-r--r-- lib/Transforms/InstCombine/InstCombineShifts.cpp | 151
-rw-r--r-- lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp | 146
-rw-r--r-- lib/Transforms/InstCombine/InstCombineVectorOps.cpp | 84
-rw-r--r-- lib/Transforms/InstCombine/InstructionCombining.cpp | 381
-rw-r--r-- lib/Transforms/InstCombine/Makefile | 15
15 files changed, 3278 insertions, 1887 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/lib/Transforms/InstCombine/InstCombineAddSub.cpp index 6f49399f57bf7..221a220071738 100644 --- a/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -58,7 +58,6 @@ namespace { // operators inevitably call FAddendCoef's constructor which is not cheap. void operator=(const FAddendCoef &A); void operator+=(const FAddendCoef &A); - void operator-=(const FAddendCoef &A); void operator*=(const FAddendCoef &S); bool isOne() const { return isInt() && IntVal == 1; } @@ -123,11 +122,18 @@ namespace { bool isConstant() const { return Val == nullptr; } bool isZero() const { return Coeff.isZero(); } - void set(short Coefficient, Value *V) { Coeff.set(Coefficient), Val = V; } - void set(const APFloat& Coefficient, Value *V) - { Coeff.set(Coefficient); Val = V; } - void set(const ConstantFP* Coefficient, Value *V) - { Coeff.set(Coefficient->getValueAPF()); Val = V; } + void set(short Coefficient, Value *V) { + Coeff.set(Coefficient); + Val = V; + } + void set(const APFloat &Coefficient, Value *V) { + Coeff.set(Coefficient); + Val = V; + } + void set(const ConstantFP *Coefficient, Value *V) { + Coeff.set(Coefficient->getValueAPF()); + Val = V; + } void negate() { Coeff.negate(); } @@ -272,27 +278,6 @@ void FAddendCoef::operator+=(const FAddendCoef &That) { T.add(createAPFloatFromInt(T.getSemantics(), That.IntVal), RndMode); } -void FAddendCoef::operator-=(const FAddendCoef &That) { - enum APFloat::roundingMode RndMode = APFloat::rmNearestTiesToEven; - if (isInt() == That.isInt()) { - if (isInt()) - IntVal -= That.IntVal; - else - getFpVal().subtract(That.getFpVal(), RndMode); - return; - } - - if (isInt()) { - const APFloat &T = That.getFpVal(); - convertToFpType(T.getSemantics()); - getFpVal().subtract(T, RndMode); - return; - } - - APFloat &T = getFpVal(); - T.subtract(createAPFloatFromInt(T.getSemantics(), IntVal), RndMode); -} - void FAddendCoef::operator*=(const FAddendCoef &That) { if (That.isOne()) return; @@ -321,8 +306,6 @@ void FAddendCoef::operator*=(const FAddendCoef &That) { APFloat::rmNearestTiesToEven); else F0.multiply(That.getFpVal(), APFloat::rmNearestTiesToEven); - - return; } void FAddendCoef::negate() { @@ -716,10 +699,9 @@ Value *FAddCombine::createNaryFAdd bool LastValNeedNeg = false; // Iterate the addends, creating fadd/fsub using adjacent two addends. - for (AddendVect::const_iterator I = Opnds.begin(), E = Opnds.end(); - I != E; I++) { + for (const FAddend *Opnd : Opnds) { bool NeedNeg; - Value *V = createAddendVal(**I, NeedNeg); + Value *V = createAddendVal(*Opnd, NeedNeg); if (!LastVal) { LastVal = V; LastValNeedNeg = NeedNeg; @@ -808,9 +790,7 @@ unsigned FAddCombine::calcInstrNumber(const AddendVect &Opnds) { unsigned NegOpndNum = 0; // Adjust the number of instructions needed to emit the N-ary add. 
- for (AddendVect::const_iterator I = Opnds.begin(), E = Opnds.end(); - I != E; I++) { - const FAddend *Opnd = *I; + for (const FAddend *Opnd : Opnds) { if (Opnd->isConstant()) continue; @@ -1052,22 +1032,26 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyAddInst(LHS, RHS, I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // (A*B)+(A*C) -> A*(B+C) etc if (Value *V = SimplifyUsingDistributiveLaws(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); - if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) { + const APInt *Val; + if (match(RHS, m_APInt(Val))) { // X + (signbit) --> X ^ signbit - const APInt &Val = CI->getValue(); - if (Val.isSignBit()) + if (Val->isSignBit()) return BinaryOperator::CreateXor(LHS, RHS); + } + // FIXME: Use the match above instead of dyn_cast to allow these transforms + // for splat vectors. + if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) { // See if SimplifyDemandedBits can simplify this. This handles stuff like // (X & 254)+1 -> (X&254)|1 if (SimplifyDemandedInstructionBits(I)) @@ -1157,7 +1141,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { return BinaryOperator::CreateSub(LHS, V); if (Value *V = checkForNegativeOperand(I, Builder)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // A+B --> A|B iff A and B have no bits set in common. if (haveNoCommonBitsSet(LHS, RHS, DL, AC, &I, DT)) @@ -1169,6 +1153,9 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { return BinaryOperator::CreateSub(SubOne(CRHS), X); } + // FIXME: We already did a check for ConstantInt RHS above this. + // FIXME: Is this pattern covered by another fold? No regression tests fail on + // removal. if (ConstantInt *CRHS = dyn_cast<ConstantInt>(RHS)) { // (X & FF00) + xx00 -> (X+xx00) & FF00 Value *X; @@ -1317,11 +1304,11 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyFAddInst(LHS, RHS, I.getFastMathFlags(), DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (isa<Constant>(RHS)) { if (isa<PHINode>(LHS)) @@ -1415,7 +1402,7 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { if (I.hasUnsafeAlgebra()) { if (Value *V = FAddCombine(Builder).simplify(&I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); } return Changed ? &I : nullptr; @@ -1493,15 +1480,15 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifySubInst(Op0, Op1, I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // (A*B)-(A*C) -> A*(B-C) etc if (Value *V = SimplifyUsingDistributiveLaws(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // If this is a 'B = x-(-A)', change to B = x+A. 
if (Value *V = dyn_castNegVal(Op1)) { @@ -1667,13 +1654,13 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { if (match(Op0, m_PtrToInt(m_Value(LHSOp))) && match(Op1, m_PtrToInt(m_Value(RHSOp)))) if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType())) - return ReplaceInstUsesWith(I, Res); + return replaceInstUsesWith(I, Res); // trunc(p)-trunc(q) -> trunc(p-q) if (match(Op0, m_Trunc(m_PtrToInt(m_Value(LHSOp)))) && match(Op1, m_Trunc(m_PtrToInt(m_Value(RHSOp))))) if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType())) - return ReplaceInstUsesWith(I, Res); + return replaceInstUsesWith(I, Res); bool Changed = false; if (!I.hasNoSignedWrap() && WillNotOverflowSignedSub(Op0, Op1, I)) { @@ -1692,11 +1679,11 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyFSubInst(Op0, Op1, I.getFastMathFlags(), DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // fsub nsz 0, X ==> fsub nsz -0.0, X if (I.getFastMathFlags().noSignedZeros() && match(Op0, m_Zero())) { @@ -1736,7 +1723,7 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { if (I.hasUnsafeAlgebra()) { if (Value *V = FAddCombine(Builder).simplify(&I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); } return nullptr; diff --git a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 76cefd97cd8f1..1a6459b3d689a 100644 --- a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -39,30 +39,29 @@ static inline Value *dyn_castNotVal(Value *V) { } /// Similar to getICmpCode but for FCmpInst. This encodes a fcmp predicate into -/// a three bit mask. It also returns whether it is an ordered predicate by -/// reference. -static unsigned getFCmpCode(FCmpInst::Predicate CC, bool &isOrdered) { - isOrdered = false; - switch (CC) { - case FCmpInst::FCMP_ORD: isOrdered = true; return 0; // 000 - case FCmpInst::FCMP_UNO: return 0; // 000 - case FCmpInst::FCMP_OGT: isOrdered = true; return 1; // 001 - case FCmpInst::FCMP_UGT: return 1; // 001 - case FCmpInst::FCMP_OEQ: isOrdered = true; return 2; // 010 - case FCmpInst::FCMP_UEQ: return 2; // 010 - case FCmpInst::FCMP_OGE: isOrdered = true; return 3; // 011 - case FCmpInst::FCMP_UGE: return 3; // 011 - case FCmpInst::FCMP_OLT: isOrdered = true; return 4; // 100 - case FCmpInst::FCMP_ULT: return 4; // 100 - case FCmpInst::FCMP_ONE: isOrdered = true; return 5; // 101 - case FCmpInst::FCMP_UNE: return 5; // 101 - case FCmpInst::FCMP_OLE: isOrdered = true; return 6; // 110 - case FCmpInst::FCMP_ULE: return 6; // 110 - // True -> 7 - default: - // Not expecting FCMP_FALSE and FCMP_TRUE; - llvm_unreachable("Unexpected FCmp predicate!"); - } +/// a four bit mask. +static unsigned getFCmpCode(FCmpInst::Predicate CC) { + assert(FCmpInst::FCMP_FALSE <= CC && CC <= FCmpInst::FCMP_TRUE && + "Unexpected FCmp predicate!"); + // Take advantage of the bit pattern of FCmpInst::Predicate here. 
+ // U L G E + static_assert(FCmpInst::FCMP_FALSE == 0, ""); // 0 0 0 0 + static_assert(FCmpInst::FCMP_OEQ == 1, ""); // 0 0 0 1 + static_assert(FCmpInst::FCMP_OGT == 2, ""); // 0 0 1 0 + static_assert(FCmpInst::FCMP_OGE == 3, ""); // 0 0 1 1 + static_assert(FCmpInst::FCMP_OLT == 4, ""); // 0 1 0 0 + static_assert(FCmpInst::FCMP_OLE == 5, ""); // 0 1 0 1 + static_assert(FCmpInst::FCMP_ONE == 6, ""); // 0 1 1 0 + static_assert(FCmpInst::FCMP_ORD == 7, ""); // 0 1 1 1 + static_assert(FCmpInst::FCMP_UNO == 8, ""); // 1 0 0 0 + static_assert(FCmpInst::FCMP_UEQ == 9, ""); // 1 0 0 1 + static_assert(FCmpInst::FCMP_UGT == 10, ""); // 1 0 1 0 + static_assert(FCmpInst::FCMP_UGE == 11, ""); // 1 0 1 1 + static_assert(FCmpInst::FCMP_ULT == 12, ""); // 1 1 0 0 + static_assert(FCmpInst::FCMP_ULE == 13, ""); // 1 1 0 1 + static_assert(FCmpInst::FCMP_UNE == 14, ""); // 1 1 1 0 + static_assert(FCmpInst::FCMP_TRUE == 15, ""); // 1 1 1 1 + return CC; } /// This is the complement of getICmpCode, which turns an opcode and two @@ -78,26 +77,16 @@ static Value *getNewICmpValue(bool Sign, unsigned Code, Value *LHS, Value *RHS, } /// This is the complement of getFCmpCode, which turns an opcode and two -/// operands into either a FCmp instruction. isordered is passed in to determine -/// which kind of predicate to use in the new fcmp instruction. -static Value *getFCmpValue(bool isordered, unsigned code, - Value *LHS, Value *RHS, +/// operands into either a FCmp instruction, or a true/false constant. +static Value *getFCmpValue(unsigned Code, Value *LHS, Value *RHS, InstCombiner::BuilderTy *Builder) { - CmpInst::Predicate Pred; - switch (code) { - default: llvm_unreachable("Illegal FCmp code!"); - case 0: Pred = isordered ? FCmpInst::FCMP_ORD : FCmpInst::FCMP_UNO; break; - case 1: Pred = isordered ? FCmpInst::FCMP_OGT : FCmpInst::FCMP_UGT; break; - case 2: Pred = isordered ? FCmpInst::FCMP_OEQ : FCmpInst::FCMP_UEQ; break; - case 3: Pred = isordered ? FCmpInst::FCMP_OGE : FCmpInst::FCMP_UGE; break; - case 4: Pred = isordered ? FCmpInst::FCMP_OLT : FCmpInst::FCMP_ULT; break; - case 5: Pred = isordered ? FCmpInst::FCMP_ONE : FCmpInst::FCMP_UNE; break; - case 6: Pred = isordered ? FCmpInst::FCMP_OLE : FCmpInst::FCMP_ULE; break; - case 7: - if (!isordered) - return ConstantInt::get(CmpInst::makeCmpResultType(LHS->getType()), 1); - Pred = FCmpInst::FCMP_ORD; break; - } + const auto Pred = static_cast<FCmpInst::Predicate>(Code); + assert(FCmpInst::FCMP_FALSE <= Pred && Pred <= FCmpInst::FCMP_TRUE && + "Unexpected FCmp predicate!"); + if (Pred == FCmpInst::FCMP_FALSE) + return ConstantInt::get(CmpInst::makeCmpResultType(LHS->getType()), 0); + if (Pred == FCmpInst::FCMP_TRUE) + return ConstantInt::get(CmpInst::makeCmpResultType(LHS->getType()), 1); return Builder->CreateFCmp(Pred, LHS, RHS); } @@ -243,7 +232,7 @@ Instruction *InstCombiner::OptAndOp(Instruction *Op, if (CI->getValue() == ShlMask) // Masking out bits that the shift already masks. - return ReplaceInstUsesWith(TheAnd, Op); // No need for the and. + return replaceInstUsesWith(TheAnd, Op); // No need for the and. if (CI != AndRHS) { // Reducing bits set in and. TheAnd.setOperand(1, CI); @@ -263,7 +252,7 @@ Instruction *InstCombiner::OptAndOp(Instruction *Op, if (CI->getValue() == ShrMask) // Masking out bits that the shift already masks. - return ReplaceInstUsesWith(TheAnd, Op); + return replaceInstUsesWith(TheAnd, Op); if (CI != AndRHS) { TheAnd.setOperand(1, CI); // Reduce bits set in and cst. 
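[Illustrative aside, not part of the patch] With the four-bit U/L/G/E encoding shown above, an 'and' or 'or' of two fcmps over the same operands reduces to a bitwise AND/OR of their predicate codes, which is what the FoldAndOfFCmps/FoldOrOfFCmps changes later in this patch rely on. A minimal standalone C++ sketch of that idea (the enum and helper names here are invented for the example, not InstCombine's actual entry points):

#include <cassert>

// Assumed 4-bit predicate encoding (bits: Unordered, Less, Greater, Equal),
// mirroring the static_asserts above for FCmpInst::Predicate.
enum FCode : unsigned {
  FC_FALSE = 0x0, FC_OEQ = 0x1, FC_OGT = 0x2, FC_OGE = 0x3,
  FC_OLT   = 0x4, FC_OLE = 0x5, FC_ONE = 0x6, FC_ORD = 0x7,
  FC_UNO   = 0x8, FC_UEQ = 0x9, FC_UGT = 0xA, FC_UGE = 0xB,
  FC_ULT   = 0xC, FC_ULE = 0xD, FC_UNE = 0xE, FC_TRUE = 0xF
};

// (fcmp cc0 x, y) & (fcmp cc1 x, y) --> fcmp (cc0 & cc1) x, y
unsigned foldFCmpAnd(unsigned C0, unsigned C1) { return C0 & C1; }
// (fcmp cc0 x, y) | (fcmp cc1 x, y) --> fcmp (cc0 | cc1) x, y
unsigned foldFCmpOr(unsigned C0, unsigned C1) { return C0 | C1; }

int main() {
  assert(foldFCmpAnd(FC_UNO, FC_ORD) == FC_FALSE); // uno && ord -> false
  assert(foldFCmpAnd(FC_ORD, FC_UEQ) == FC_OEQ);   // ord && ueq -> oeq
  assert(foldFCmpOr(FC_OLT, FC_OGT) == FC_ONE);    // olt || ogt -> one
  return 0;
}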
@@ -465,11 +454,9 @@ static unsigned getTypeOfMaskedICmp(Value* A, Value* B, Value* C, if (CCst && CCst->isZero()) { // if C is zero, then both A and B qualify as mask result |= (icmp_eq ? (FoldMskICmp_Mask_AllZeroes | - FoldMskICmp_Mask_AllZeroes | FoldMskICmp_AMask_Mixed | FoldMskICmp_BMask_Mixed) : (FoldMskICmp_Mask_NotAllZeroes | - FoldMskICmp_Mask_NotAllZeroes | FoldMskICmp_AMask_NotMixed | FoldMskICmp_BMask_NotMixed)); if (icmp_abit) @@ -666,7 +653,7 @@ static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A, if (!ICmpInst::isEquality(RHSCC)) return 0; - // Look for ANDs in on the right side of the RHS icmp. + // Look for ANDs on the right side of the RHS icmp. if (!ok && R2->getType()->isIntegerTy()) { if (!match(R2, m_And(m_Value(R11), m_Value(R12)))) { R11 = R2; @@ -694,9 +681,9 @@ static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A, B = L21; C = L1; } - unsigned left_type = getTypeOfMaskedICmp(A, B, C, LHSCC); - unsigned right_type = getTypeOfMaskedICmp(A, D, E, RHSCC); - return left_type & right_type; + unsigned LeftType = getTypeOfMaskedICmp(A, B, C, LHSCC); + unsigned RightType = getTypeOfMaskedICmp(A, D, E, RHSCC); + return LeftType & RightType; } /// Try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E) @@ -705,9 +692,9 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, llvm::InstCombiner::BuilderTy *Builder) { Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr; ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate(); - unsigned mask = foldLogOpOfMaskedICmpsHelper(A, B, C, D, E, LHS, RHS, + unsigned Mask = foldLogOpOfMaskedICmpsHelper(A, B, C, D, E, LHS, RHS, LHSCC, RHSCC); - if (mask == 0) return nullptr; + if (Mask == 0) return nullptr; assert(ICmpInst::isEquality(LHSCC) && ICmpInst::isEquality(RHSCC) && "foldLogOpOfMaskedICmpsHelper must return an equality predicate."); @@ -723,48 +710,48 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, // input and output). // In most cases we're going to produce an EQ for the "&&" case. - ICmpInst::Predicate NEWCC = IsAnd ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE; + ICmpInst::Predicate NewCC = IsAnd ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE; if (!IsAnd) { // Convert the masking analysis into its equivalent with negated // comparisons. - mask = conjugateICmpMask(mask); + Mask = conjugateICmpMask(Mask); } - if (mask & FoldMskICmp_Mask_AllZeroes) { + if (Mask & FoldMskICmp_Mask_AllZeroes) { // (icmp eq (A & B), 0) & (icmp eq (A & D), 0) // -> (icmp eq (A & (B|D)), 0) - Value *newOr = Builder->CreateOr(B, D); - Value *newAnd = Builder->CreateAnd(A, newOr); - // we can't use C as zero, because we might actually handle + Value *NewOr = Builder->CreateOr(B, D); + Value *NewAnd = Builder->CreateAnd(A, NewOr); + // We can't use C as zero because we might actually handle // (icmp ne (A & B), B) & (icmp ne (A & D), D) - // with B and D, having a single bit set - Value *zero = Constant::getNullValue(A->getType()); - return Builder->CreateICmp(NEWCC, newAnd, zero); + // with B and D, having a single bit set. 
+ Value *Zero = Constant::getNullValue(A->getType()); + return Builder->CreateICmp(NewCC, NewAnd, Zero); } - if (mask & FoldMskICmp_BMask_AllOnes) { + if (Mask & FoldMskICmp_BMask_AllOnes) { // (icmp eq (A & B), B) & (icmp eq (A & D), D) // -> (icmp eq (A & (B|D)), (B|D)) - Value *newOr = Builder->CreateOr(B, D); - Value *newAnd = Builder->CreateAnd(A, newOr); - return Builder->CreateICmp(NEWCC, newAnd, newOr); + Value *NewOr = Builder->CreateOr(B, D); + Value *NewAnd = Builder->CreateAnd(A, NewOr); + return Builder->CreateICmp(NewCC, NewAnd, NewOr); } - if (mask & FoldMskICmp_AMask_AllOnes) { + if (Mask & FoldMskICmp_AMask_AllOnes) { // (icmp eq (A & B), A) & (icmp eq (A & D), A) // -> (icmp eq (A & (B&D)), A) - Value *newAnd1 = Builder->CreateAnd(B, D); - Value *newAnd = Builder->CreateAnd(A, newAnd1); - return Builder->CreateICmp(NEWCC, newAnd, A); + Value *NewAnd1 = Builder->CreateAnd(B, D); + Value *NewAnd2 = Builder->CreateAnd(A, NewAnd1); + return Builder->CreateICmp(NewCC, NewAnd2, A); } // Remaining cases assume at least that B and D are constant, and depend on - // their actual values. This isn't strictly, necessary, just a "handle the + // their actual values. This isn't strictly necessary, just a "handle the // easy cases for now" decision. ConstantInt *BCst = dyn_cast<ConstantInt>(B); if (!BCst) return nullptr; ConstantInt *DCst = dyn_cast<ConstantInt>(D); if (!DCst) return nullptr; - if (mask & (FoldMskICmp_Mask_NotAllZeroes | FoldMskICmp_BMask_NotAllOnes)) { + if (Mask & (FoldMskICmp_Mask_NotAllZeroes | FoldMskICmp_BMask_NotAllOnes)) { // (icmp ne (A & B), 0) & (icmp ne (A & D), 0) and // (icmp ne (A & B), B) & (icmp ne (A & D), D) // -> (icmp ne (A & B), 0) or (icmp ne (A & D), 0) @@ -777,7 +764,7 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, else if (NewMask == DCst->getValue()) return RHS; } - if (mask & FoldMskICmp_AMask_NotAllOnes) { + if (Mask & FoldMskICmp_AMask_NotAllOnes) { // (icmp ne (A & B), B) & (icmp ne (A & D), D) // -> (icmp ne (A & B), A) or (icmp ne (A & D), A) // Only valid if one of the masks is a superset of the other (check "B|D" is @@ -789,7 +776,7 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, else if (NewMask == DCst->getValue()) return RHS; } - if (mask & FoldMskICmp_BMask_Mixed) { + if (Mask & FoldMskICmp_BMask_Mixed) { // (icmp eq (A & B), C) & (icmp eq (A & D), E) // We already know that B & C == C && D & E == E. // If we can prove that (B & D) & (C ^ E) == 0, that is, the bits of @@ -797,26 +784,26 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, // contradict, then we can transform to // -> (icmp eq (A & (B|D)), (C|E)) // Currently, we only handle the case of B, C, D, and E being constant. - // we can't simply use C and E, because we might actually handle + // We can't simply use C and E because we might actually handle // (icmp ne (A & B), B) & (icmp eq (A & D), D) - // with B and D, having a single bit set + // with B and D, having a single bit set. 
ConstantInt *CCst = dyn_cast<ConstantInt>(C); if (!CCst) return nullptr; ConstantInt *ECst = dyn_cast<ConstantInt>(E); if (!ECst) return nullptr; - if (LHSCC != NEWCC) + if (LHSCC != NewCC) CCst = cast<ConstantInt>(ConstantExpr::getXor(BCst, CCst)); - if (RHSCC != NEWCC) + if (RHSCC != NewCC) ECst = cast<ConstantInt>(ConstantExpr::getXor(DCst, ECst)); - // if there is a conflict we should actually return a false for the - // whole construct + // If there is a conflict, we should actually return a false for the + // whole construct. if (((BCst->getValue() & DCst->getValue()) & (CCst->getValue() ^ ECst->getValue())) != 0) return ConstantInt::get(LHS->getType(), !IsAnd); - Value *newOr1 = Builder->CreateOr(B, D); - Value *newOr2 = ConstantExpr::getOr(CCst, ECst); - Value *newAnd = Builder->CreateAnd(A, newOr1); - return Builder->CreateICmp(NEWCC, newAnd, newOr2); + Value *NewOr1 = Builder->CreateOr(B, D); + Value *NewOr2 = ConstantExpr::getOr(CCst, ECst); + Value *NewAnd = Builder->CreateAnd(A, NewOr1); + return Builder->CreateICmp(NewCC, NewAnd, NewOr2); } return nullptr; } @@ -915,15 +902,10 @@ Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) { if (LHSCst == RHSCst && LHSCC == RHSCC) { // (icmp ult A, C) & (icmp ult B, C) --> (icmp ult (A|B), C) - // where C is a power of 2 - if (LHSCC == ICmpInst::ICMP_ULT && - LHSCst->getValue().isPowerOf2()) { - Value *NewOr = Builder->CreateOr(Val, Val2); - return Builder->CreateICmp(LHSCC, NewOr, LHSCst); - } - + // where C is a power of 2 or // (icmp eq A, 0) & (icmp eq B, 0) --> (icmp eq (A|B), 0) - if (LHSCC == ICmpInst::ICMP_EQ && LHSCst->isZero()) { + if ((LHSCC == ICmpInst::ICMP_ULT && LHSCst->getValue().isPowerOf2()) || + (LHSCC == ICmpInst::ICMP_EQ && LHSCst->isZero())) { Value *NewOr = Builder->CreateOr(Val, Val2); return Builder->CreateICmp(LHSCC, NewOr, LHSCst); } @@ -975,16 +957,6 @@ Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) { RHSCC == ICmpInst::ICMP_SGE || RHSCC == ICmpInst::ICMP_SLE) return nullptr; - // Make a constant range that's the intersection of the two icmp ranges. - // If the intersection is empty, we know that the result is false. - ConstantRange LHSRange = - ConstantRange::makeAllowedICmpRegion(LHSCC, LHSCst->getValue()); - ConstantRange RHSRange = - ConstantRange::makeAllowedICmpRegion(RHSCC, RHSCst->getValue()); - - if (LHSRange.intersectWith(RHSRange).isEmptySet()) - return ConstantInt::get(CmpInst::makeCmpResultType(LHS->getType()), 0); - // We can't fold (ugt x, C) & (sgt x, C2). if (!PredicatesFoldable(LHSCC, RHSCC)) return nullptr; @@ -1124,6 +1096,29 @@ Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) { /// Optimize (fcmp)&(fcmp). NOTE: Unlike the rest of instcombine, this returns /// a Value which should already be inserted into the function. Value *InstCombiner::FoldAndOfFCmps(FCmpInst *LHS, FCmpInst *RHS) { + Value *Op0LHS = LHS->getOperand(0), *Op0RHS = LHS->getOperand(1); + Value *Op1LHS = RHS->getOperand(0), *Op1RHS = RHS->getOperand(1); + FCmpInst::Predicate Op0CC = LHS->getPredicate(), Op1CC = RHS->getPredicate(); + + if (Op0LHS == Op1RHS && Op0RHS == Op1LHS) { + // Swap RHS operands to match LHS. + Op1CC = FCmpInst::getSwappedPredicate(Op1CC); + std::swap(Op1LHS, Op1RHS); + } + + // Simplify (fcmp cc0 x, y) & (fcmp cc1 x, y). + // Suppose the relation between x and y is R, where R is one of + // U(1000), L(0100), G(0010) or E(0001), and CC0 and CC1 are the bitmasks for + // testing the desired relations. 
+ // + // Since (R & CC0) and (R & CC1) are either R or 0, we actually have this: + // bool(R & CC0) && bool(R & CC1) + // = bool((R & CC0) & (R & CC1)) + // = bool(R & (CC0 & CC1)) <= by re-association, commutation, and idempotency + if (Op0LHS == Op1LHS && Op0RHS == Op1RHS) + return getFCmpValue(getFCmpCode(Op0CC) & getFCmpCode(Op1CC), Op0LHS, Op0RHS, + Builder); + if (LHS->getPredicate() == FCmpInst::FCMP_ORD && RHS->getPredicate() == FCmpInst::FCMP_ORD) { if (LHS->getOperand(0)->getType() != RHS->getOperand(0)->getType()) @@ -1147,56 +1142,6 @@ Value *InstCombiner::FoldAndOfFCmps(FCmpInst *LHS, FCmpInst *RHS) { return nullptr; } - Value *Op0LHS = LHS->getOperand(0), *Op0RHS = LHS->getOperand(1); - Value *Op1LHS = RHS->getOperand(0), *Op1RHS = RHS->getOperand(1); - FCmpInst::Predicate Op0CC = LHS->getPredicate(), Op1CC = RHS->getPredicate(); - - - if (Op0LHS == Op1RHS && Op0RHS == Op1LHS) { - // Swap RHS operands to match LHS. - Op1CC = FCmpInst::getSwappedPredicate(Op1CC); - std::swap(Op1LHS, Op1RHS); - } - - if (Op0LHS == Op1LHS && Op0RHS == Op1RHS) { - // Simplify (fcmp cc0 x, y) & (fcmp cc1 x, y). - if (Op0CC == Op1CC) - return Builder->CreateFCmp((FCmpInst::Predicate)Op0CC, Op0LHS, Op0RHS); - if (Op0CC == FCmpInst::FCMP_FALSE || Op1CC == FCmpInst::FCMP_FALSE) - return ConstantInt::get(CmpInst::makeCmpResultType(LHS->getType()), 0); - if (Op0CC == FCmpInst::FCMP_TRUE) - return RHS; - if (Op1CC == FCmpInst::FCMP_TRUE) - return LHS; - - bool Op0Ordered; - bool Op1Ordered; - unsigned Op0Pred = getFCmpCode(Op0CC, Op0Ordered); - unsigned Op1Pred = getFCmpCode(Op1CC, Op1Ordered); - // uno && ord -> false - if (Op0Pred == 0 && Op1Pred == 0 && Op0Ordered != Op1Ordered) - return ConstantInt::get(CmpInst::makeCmpResultType(LHS->getType()), 0); - if (Op1Pred == 0) { - std::swap(LHS, RHS); - std::swap(Op0Pred, Op1Pred); - std::swap(Op0Ordered, Op1Ordered); - } - if (Op0Pred == 0) { - // uno && ueq -> uno && (uno || eq) -> uno - // ord && olt -> ord && (ord && lt) -> olt - if (!Op0Ordered && (Op0Ordered == Op1Ordered)) - return LHS; - if (Op0Ordered && (Op0Ordered == Op1Ordered)) - return RHS; - - // uno && oeq -> uno && (ord && eq) -> false - if (!Op0Ordered) - return ConstantInt::get(CmpInst::makeCmpResultType(LHS->getType()), 0); - // ord && ueq -> ord && (uno || eq) -> oeq - return getFCmpValue(true, Op1Pred, Op0LHS, Op0RHS, Builder); - } - } - return nullptr; } @@ -1248,19 +1193,131 @@ static Instruction *matchDeMorgansLaws(BinaryOperator &I, return nullptr; } +Instruction *InstCombiner::foldCastedBitwiseLogic(BinaryOperator &I) { + auto LogicOpc = I.getOpcode(); + assert((LogicOpc == Instruction::And || LogicOpc == Instruction::Or || + LogicOpc == Instruction::Xor) && + "Unexpected opcode for bitwise logic folding"); + + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + CastInst *Cast0 = dyn_cast<CastInst>(Op0); + if (!Cast0) + return nullptr; + + // This must be a cast from an integer or integer vector source type to allow + // transformation of the logic operation to the source type. + Type *DestTy = I.getType(); + Type *SrcTy = Cast0->getSrcTy(); + if (!SrcTy->isIntOrIntVectorTy()) + return nullptr; + + // If one operand is a bitcast and the other is a constant, move the logic + // operation ahead of the bitcast. That is, do the logic operation in the + // original type. This can eliminate useless bitcasts and allow normal + // combines that would otherwise be impeded by the bitcast. Canonicalization + // ensures that if there is a constant operand, it will be the second operand. 
+ Value *BC = nullptr; + Constant *C = nullptr; + if ((match(Op0, m_BitCast(m_Value(BC))) && match(Op1, m_Constant(C)))) { + Value *NewConstant = ConstantExpr::getBitCast(C, SrcTy); + Value *NewOp = Builder->CreateBinOp(LogicOpc, BC, NewConstant, I.getName()); + return CastInst::CreateBitOrPointerCast(NewOp, DestTy); + } + + CastInst *Cast1 = dyn_cast<CastInst>(Op1); + if (!Cast1) + return nullptr; + + // Both operands of the logic operation are casts. The casts must be of the + // same type for reduction. + auto CastOpcode = Cast0->getOpcode(); + if (CastOpcode != Cast1->getOpcode() || SrcTy != Cast1->getSrcTy()) + return nullptr; + + Value *Cast0Src = Cast0->getOperand(0); + Value *Cast1Src = Cast1->getOperand(0); + + // fold (logic (cast A), (cast B)) -> (cast (logic A, B)) + + // Only do this if the casts both really cause code to be generated. + if ((!isa<ICmpInst>(Cast0Src) || !isa<ICmpInst>(Cast1Src)) && + ShouldOptimizeCast(CastOpcode, Cast0Src, DestTy) && + ShouldOptimizeCast(CastOpcode, Cast1Src, DestTy)) { + Value *NewOp = Builder->CreateBinOp(LogicOpc, Cast0Src, Cast1Src, + I.getName()); + return CastInst::Create(CastOpcode, NewOp, DestTy); + } + + // For now, only 'and'/'or' have optimizations after this. + if (LogicOpc == Instruction::Xor) + return nullptr; + + // If this is logic(cast(icmp), cast(icmp)), try to fold this even if the + // cast is otherwise not optimizable. This happens for vector sexts. + ICmpInst *ICmp0 = dyn_cast<ICmpInst>(Cast0Src); + ICmpInst *ICmp1 = dyn_cast<ICmpInst>(Cast1Src); + if (ICmp0 && ICmp1) { + Value *Res = LogicOpc == Instruction::And ? FoldAndOfICmps(ICmp0, ICmp1) + : FoldOrOfICmps(ICmp0, ICmp1, &I); + if (Res) + return CastInst::Create(CastOpcode, Res, DestTy); + return nullptr; + } + + // If this is logic(cast(fcmp), cast(fcmp)), try to fold this even if the + // cast is otherwise not optimizable. This happens for vector sexts. + FCmpInst *FCmp0 = dyn_cast<FCmpInst>(Cast0Src); + FCmpInst *FCmp1 = dyn_cast<FCmpInst>(Cast1Src); + if (FCmp0 && FCmp1) { + Value *Res = LogicOpc == Instruction::And ? 
FoldAndOfFCmps(FCmp0, FCmp1) + : FoldOrOfFCmps(FCmp0, FCmp1); + if (Res) + return CastInst::Create(CastOpcode, Res, DestTy); + return nullptr; + } + + return nullptr; +} + +static Instruction *foldBoolSextMaskToSelect(BinaryOperator &I) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + // Canonicalize SExt or Not to the LHS + if (match(Op1, m_SExt(m_Value())) || match(Op1, m_Not(m_Value()))) { + std::swap(Op0, Op1); + } + + // Fold (and (sext bool to A), B) --> (select bool, B, 0) + Value *X = nullptr; + if (match(Op0, m_SExt(m_Value(X))) && + X->getType()->getScalarType()->isIntegerTy(1)) { + Value *Zero = Constant::getNullValue(Op1->getType()); + return SelectInst::Create(X, Op1, Zero); + } + + // Fold (and ~(sext bool to A), B) --> (select bool, 0, B) + if (match(Op0, m_Not(m_SExt(m_Value(X)))) && + X->getType()->getScalarType()->isIntegerTy(1)) { + Value *Zero = Constant::getNullValue(Op0->getType()); + return SelectInst::Create(X, Zero, Op1); + } + + return nullptr; +} + Instruction *InstCombiner::visitAnd(BinaryOperator &I) { bool Changed = SimplifyAssociativeOrCommutative(I); Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyAndInst(Op0, Op1, DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // (A|B)&(A|C) -> A|(B&C) etc if (Value *V = SimplifyUsingDistributiveLaws(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // See if we can simplify any instructions used by the instruction whose sole // purpose is to compute bits we don't care about. @@ -1268,7 +1325,7 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { return &I; if (Value *V = SimplifyBSwap(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (ConstantInt *AndRHS = dyn_cast<ConstantInt>(Op1)) { const APInt &AndRHSMask = AndRHS->getValue(); @@ -1399,8 +1456,7 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { { Value *tmpOp0 = Op0; Value *tmpOp1 = Op1; - if (Op0->hasOneUse() && - match(Op0, m_Xor(m_Value(A), m_Value(B)))) { + if (match(Op0, m_OneUse(m_Xor(m_Value(A), m_Value(B))))) { if (A == Op1 || B == Op1 ) { tmpOp1 = Op0; tmpOp0 = Op1; @@ -1408,12 +1464,11 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { } } - if (tmpOp1->hasOneUse() && - match(tmpOp1, m_Xor(m_Value(A), m_Value(B)))) { + if (match(tmpOp1, m_OneUse(m_Xor(m_Value(A), m_Value(B))))) { if (B == tmpOp0) { std::swap(A, B); } - // Notice that the patten (A&(~B)) is actually (A&(-1^B)), so if + // Notice that the pattern (A&(~B)) is actually (A&(-1^B)), so if // A is originally -1 (or a vector of -1 and undefs), then we enter // an endless loop. By checking that A is non-constant we ensure that // we will never get to the loop. @@ -1458,7 +1513,7 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { ICmpInst *RHS = dyn_cast<ICmpInst>(Op1); if (LHS && RHS) if (Value *Res = FoldAndOfICmps(LHS, RHS)) - return ReplaceInstUsesWith(I, Res); + return replaceInstUsesWith(I, Res); // TODO: Make this recursive; it's a little tricky because an arbitrary // number of 'and' instructions might have to be created. 
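[Illustrative aside, not part of the patch] The foldBoolSextMaskToSelect helper introduced above works because a sign-extended i1 is either all-ones or all-zeros, so AND-ing with it behaves exactly like a select. A small scalar C++ model of that equivalence (function names are made up for the example):

#include <cassert>
#include <cstdint>

// Model of: (and (sext i1 %b to i32), %v) --> (select i1 %b, i32 %v, i32 0)
int32_t andOfSextBool(bool B, int32_t V) {
  int32_t Mask = B ? -1 : 0; // sext i1 -> i32 yields all-ones or all-zeros
  return Mask & V;
}

int32_t selectForm(bool B, int32_t V) { return B ? V : 0; }

int main() {
  for (int32_t V : {0, 1, -7, 123456, INT32_MIN})
    for (bool B : {false, true})
      assert(andOfSextBool(B, V) == selectForm(B, V));
  return 0;
}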
@@ -1466,18 +1521,18 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { if (LHS && match(Op1, m_OneUse(m_And(m_Value(X), m_Value(Y))))) { if (auto *Cmp = dyn_cast<ICmpInst>(X)) if (Value *Res = FoldAndOfICmps(LHS, Cmp)) - return ReplaceInstUsesWith(I, Builder->CreateAnd(Res, Y)); + return replaceInstUsesWith(I, Builder->CreateAnd(Res, Y)); if (auto *Cmp = dyn_cast<ICmpInst>(Y)) if (Value *Res = FoldAndOfICmps(LHS, Cmp)) - return ReplaceInstUsesWith(I, Builder->CreateAnd(Res, X)); + return replaceInstUsesWith(I, Builder->CreateAnd(Res, X)); } if (RHS && match(Op0, m_OneUse(m_And(m_Value(X), m_Value(Y))))) { if (auto *Cmp = dyn_cast<ICmpInst>(X)) if (Value *Res = FoldAndOfICmps(Cmp, RHS)) - return ReplaceInstUsesWith(I, Builder->CreateAnd(Res, Y)); + return replaceInstUsesWith(I, Builder->CreateAnd(Res, Y)); if (auto *Cmp = dyn_cast<ICmpInst>(Y)) if (Value *Res = FoldAndOfICmps(Cmp, RHS)) - return ReplaceInstUsesWith(I, Builder->CreateAnd(Res, X)); + return replaceInstUsesWith(I, Builder->CreateAnd(Res, X)); } } @@ -1485,92 +1540,46 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { if (FCmpInst *LHS = dyn_cast<FCmpInst>(I.getOperand(0))) if (FCmpInst *RHS = dyn_cast<FCmpInst>(I.getOperand(1))) if (Value *Res = FoldAndOfFCmps(LHS, RHS)) - return ReplaceInstUsesWith(I, Res); - - - if (CastInst *Op0C = dyn_cast<CastInst>(Op0)) { - Value *Op0COp = Op0C->getOperand(0); - Type *SrcTy = Op0COp->getType(); - // fold (and (cast A), (cast B)) -> (cast (and A, B)) - if (CastInst *Op1C = dyn_cast<CastInst>(Op1)) { - if (Op0C->getOpcode() == Op1C->getOpcode() && // same cast kind ? - SrcTy == Op1C->getOperand(0)->getType() && - SrcTy->isIntOrIntVectorTy()) { - Value *Op1COp = Op1C->getOperand(0); - - // Only do this if the casts both really cause code to be generated. - if (ShouldOptimizeCast(Op0C->getOpcode(), Op0COp, I.getType()) && - ShouldOptimizeCast(Op1C->getOpcode(), Op1COp, I.getType())) { - Value *NewOp = Builder->CreateAnd(Op0COp, Op1COp, I.getName()); - return CastInst::Create(Op0C->getOpcode(), NewOp, I.getType()); - } + return replaceInstUsesWith(I, Res); - // If this is and(cast(icmp), cast(icmp)), try to fold this even if the - // cast is otherwise not optimizable. This happens for vector sexts. - if (ICmpInst *RHS = dyn_cast<ICmpInst>(Op1COp)) - if (ICmpInst *LHS = dyn_cast<ICmpInst>(Op0COp)) - if (Value *Res = FoldAndOfICmps(LHS, RHS)) - return CastInst::Create(Op0C->getOpcode(), Res, I.getType()); - - // If this is and(cast(fcmp), cast(fcmp)), try to fold this even if the - // cast is otherwise not optimizable. This happens for vector sexts. - if (FCmpInst *RHS = dyn_cast<FCmpInst>(Op1COp)) - if (FCmpInst *LHS = dyn_cast<FCmpInst>(Op0COp)) - if (Value *Res = FoldAndOfFCmps(LHS, RHS)) - return CastInst::Create(Op0C->getOpcode(), Res, I.getType()); - } - } + if (Instruction *CastedAnd = foldCastedBitwiseLogic(I)) + return CastedAnd; - // If we are masking off the sign bit of a floating-point value, convert - // this to the canonical fabs intrinsic call and cast back to integer. - // The backend should know how to optimize fabs(). - // TODO: This transform should also apply to vectors. 
- ConstantInt *CI; - if (isa<BitCastInst>(Op0C) && SrcTy->isFloatingPointTy() && - match(Op1, m_ConstantInt(CI)) && CI->isMaxValue(true)) { - Module *M = I.getModule(); - Function *Fabs = Intrinsic::getDeclaration(M, Intrinsic::fabs, SrcTy); - Value *Call = Builder->CreateCall(Fabs, Op0COp, "fabs"); - return CastInst::CreateBitOrPointerCast(Call, I.getType()); - } - } + if (Instruction *Select = foldBoolSextMaskToSelect(I)) + return Select; - { - Value *X = nullptr; - bool OpsSwapped = false; - // Canonicalize SExt or Not to the LHS - if (match(Op1, m_SExt(m_Value())) || - match(Op1, m_Not(m_Value()))) { - std::swap(Op0, Op1); - OpsSwapped = true; - } + return Changed ? &I : nullptr; +} - // Fold (and (sext bool to A), B) --> (select bool, B, 0) - if (match(Op0, m_SExt(m_Value(X))) && - X->getType()->getScalarType()->isIntegerTy(1)) { - Value *Zero = Constant::getNullValue(Op1->getType()); - return SelectInst::Create(X, Op1, Zero); - } +/// Given an OR instruction, check to see if this is a bswap idiom. If so, +/// insert the new intrinsic and return it. +Instruction *InstCombiner::MatchBSwap(BinaryOperator &I) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - // Fold (and ~(sext bool to A), B) --> (select bool, 0, B) - if (match(Op0, m_Not(m_SExt(m_Value(X)))) && - X->getType()->getScalarType()->isIntegerTy(1)) { - Value *Zero = Constant::getNullValue(Op0->getType()); - return SelectInst::Create(X, Zero, Op1); - } + // Look through zero extends. + if (Instruction *Ext = dyn_cast<ZExtInst>(Op0)) + Op0 = Ext->getOperand(0); - if (OpsSwapped) - std::swap(Op0, Op1); - } + if (Instruction *Ext = dyn_cast<ZExtInst>(Op1)) + Op1 = Ext->getOperand(0); - return Changed ? &I : nullptr; -} + // (A | B) | C and A | (B | C) -> bswap if possible. + bool OrOfOrs = match(Op0, m_Or(m_Value(), m_Value())) || + match(Op1, m_Or(m_Value(), m_Value())); + + // (A >> B) | (C << D) and (A << B) | (B >> C) -> bswap if possible. + bool OrOfShifts = match(Op0, m_LogicalShift(m_Value(), m_Value())) && + match(Op1, m_LogicalShift(m_Value(), m_Value())); + + // (A & B) | (C & D) -> bswap if possible. + bool OrOfAnds = match(Op0, m_And(m_Value(), m_Value())) && + match(Op1, m_And(m_Value(), m_Value())); + + if (!OrOfOrs && !OrOfShifts && !OrOfAnds) + return nullptr; -/// Given an OR instruction, check to see if this is a bswap or bitreverse -/// idiom. If so, insert the new intrinsic and return it. -Instruction *InstCombiner::MatchBSwapOrBitReverse(BinaryOperator &I) { SmallVector<Instruction*, 4> Insts; - if (!recognizeBitReverseOrBSwapIdiom(&I, true, false, Insts)) + if (!recognizeBSwapOrBitReverseIdiom(&I, true, false, Insts)) return nullptr; Instruction *LastInst = Insts.pop_back_val(); LastInst->removeFromParent(); @@ -1580,28 +1589,89 @@ Instruction *InstCombiner::MatchBSwapOrBitReverse(BinaryOperator &I) { return LastInst; } -/// We have an expression of the form (A&C)|(B&D). Check if A is (cond?-1:0) -/// and either B or D is ~(cond?-1,0) or (cond?0,-1), then we can simplify this -/// expression to "cond ? C : D or B". -static Instruction *MatchSelectFromAndOr(Value *A, Value *B, - Value *C, Value *D) { - // If A is not a select of -1/0, this cannot match. - Value *Cond = nullptr; - if (!match(A, m_SExt(m_Value(Cond))) || - !Cond->getType()->isIntegerTy(1)) +/// If all elements of two constant vectors are 0/-1 and inverses, return true. 
+static bool areInverseVectorBitmasks(Constant *C1, Constant *C2) { + unsigned NumElts = C1->getType()->getVectorNumElements(); + for (unsigned i = 0; i != NumElts; ++i) { + Constant *EltC1 = C1->getAggregateElement(i); + Constant *EltC2 = C2->getAggregateElement(i); + if (!EltC1 || !EltC2) + return false; + + // One element must be all ones, and the other must be all zeros. + // FIXME: Allow undef elements. + if (!((match(EltC1, m_Zero()) && match(EltC2, m_AllOnes())) || + (match(EltC2, m_Zero()) && match(EltC1, m_AllOnes())))) + return false; + } + return true; +} + +/// We have an expression of the form (A & C) | (B & D). If A is a scalar or +/// vector composed of all-zeros or all-ones values and is the bitwise 'not' of +/// B, it can be used as the condition operand of a select instruction. +static Value *getSelectCondition(Value *A, Value *B, + InstCombiner::BuilderTy &Builder) { + // If these are scalars or vectors of i1, A can be used directly. + Type *Ty = A->getType(); + if (match(A, m_Not(m_Specific(B))) && Ty->getScalarType()->isIntegerTy(1)) + return A; + + // If A and B are sign-extended, look through the sexts to find the booleans. + Value *Cond; + if (match(A, m_SExt(m_Value(Cond))) && + Cond->getType()->getScalarType()->isIntegerTy(1) && + match(B, m_CombineOr(m_Not(m_SExt(m_Specific(Cond))), + m_SExt(m_Not(m_Specific(Cond)))))) + return Cond; + + // All scalar (and most vector) possibilities should be handled now. + // Try more matches that only apply to non-splat constant vectors. + if (!Ty->isVectorTy()) return nullptr; - // ((cond?-1:0)&C) | (B&(cond?0:-1)) -> cond ? C : B. - if (match(D, m_Not(m_SExt(m_Specific(Cond))))) - return SelectInst::Create(Cond, C, B); - if (match(D, m_SExt(m_Not(m_Specific(Cond))))) - return SelectInst::Create(Cond, C, B); - - // ((cond?-1:0)&C) | ((cond?0:-1)&D) -> cond ? C : D. - if (match(B, m_Not(m_SExt(m_Specific(Cond))))) - return SelectInst::Create(Cond, C, D); - if (match(B, m_SExt(m_Not(m_Specific(Cond))))) - return SelectInst::Create(Cond, C, D); + // If both operands are constants, see if the constants are inverse bitmasks. + Constant *AC, *BC; + if (match(A, m_Constant(AC)) && match(B, m_Constant(BC)) && + areInverseVectorBitmasks(AC, BC)) + return ConstantExpr::getTrunc(AC, CmpInst::makeCmpResultType(Ty)); + + // If both operands are xor'd with constants using the same sexted boolean + // operand, see if the constants are inverse bitmasks. + if (match(A, (m_Xor(m_SExt(m_Value(Cond)), m_Constant(AC)))) && + match(B, (m_Xor(m_SExt(m_Specific(Cond)), m_Constant(BC)))) && + Cond->getType()->getScalarType()->isIntegerTy(1) && + areInverseVectorBitmasks(AC, BC)) { + AC = ConstantExpr::getTrunc(AC, CmpInst::makeCmpResultType(Ty)); + return Builder.CreateXor(Cond, AC); + } + return nullptr; +} + +/// We have an expression of the form (A & C) | (B & D). Try to simplify this +/// to "A' ? C : D", where A' is a boolean or vector of booleans. +static Value *matchSelectFromAndOr(Value *A, Value *C, Value *B, Value *D, + InstCombiner::BuilderTy &Builder) { + // The potential condition of the select may be bitcasted. In that case, look + // through its bitcast and the corresponding bitcast of the 'not' condition. 
+ Type *OrigType = A->getType(); + Value *SrcA, *SrcB; + if (match(A, m_OneUse(m_BitCast(m_Value(SrcA)))) && + match(B, m_OneUse(m_BitCast(m_Value(SrcB))))) { + A = SrcA; + B = SrcB; + } + + if (Value *Cond = getSelectCondition(A, B, Builder)) { + // ((bc Cond) & C) | ((bc ~Cond) & D) --> bc (select Cond, (bc C), (bc D)) + // The bitcasts will either all exist or all not exist. The builder will + // not create unnecessary casts if the types already match. + Value *BitcastC = Builder.CreateBitCast(C, A->getType()); + Value *BitcastD = Builder.CreateBitCast(D, A->getType()); + Value *Select = Builder.CreateSelect(Cond, BitcastC, BitcastD); + return Builder.CreateBitCast(Select, OrigType); + } + return nullptr; } @@ -1940,6 +2010,27 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, /// Optimize (fcmp)|(fcmp). NOTE: Unlike the rest of instcombine, this returns /// a Value which should already be inserted into the function. Value *InstCombiner::FoldOrOfFCmps(FCmpInst *LHS, FCmpInst *RHS) { + Value *Op0LHS = LHS->getOperand(0), *Op0RHS = LHS->getOperand(1); + Value *Op1LHS = RHS->getOperand(0), *Op1RHS = RHS->getOperand(1); + FCmpInst::Predicate Op0CC = LHS->getPredicate(), Op1CC = RHS->getPredicate(); + + if (Op0LHS == Op1RHS && Op0RHS == Op1LHS) { + // Swap RHS operands to match LHS. + Op1CC = FCmpInst::getSwappedPredicate(Op1CC); + std::swap(Op1LHS, Op1RHS); + } + + // Simplify (fcmp cc0 x, y) | (fcmp cc1 x, y). + // This is a similar transformation to the one in FoldAndOfFCmps. + // + // Since (R & CC0) and (R & CC1) are either R or 0, we actually have this: + // bool(R & CC0) || bool(R & CC1) + // = bool((R & CC0) | (R & CC1)) + // = bool(R & (CC0 | CC1)) <= by reversed distribution (contribution? ;) + if (Op0LHS == Op1LHS && Op0RHS == Op1RHS) + return getFCmpValue(getFCmpCode(Op0CC) | getFCmpCode(Op1CC), Op0LHS, Op0RHS, + Builder); + if (LHS->getPredicate() == FCmpInst::FCMP_UNO && RHS->getPredicate() == FCmpInst::FCMP_UNO && LHS->getOperand(0)->getType() == RHS->getOperand(0)->getType()) { @@ -1964,35 +2055,6 @@ Value *InstCombiner::FoldOrOfFCmps(FCmpInst *LHS, FCmpInst *RHS) { return nullptr; } - Value *Op0LHS = LHS->getOperand(0), *Op0RHS = LHS->getOperand(1); - Value *Op1LHS = RHS->getOperand(0), *Op1RHS = RHS->getOperand(1); - FCmpInst::Predicate Op0CC = LHS->getPredicate(), Op1CC = RHS->getPredicate(); - - if (Op0LHS == Op1RHS && Op0RHS == Op1LHS) { - // Swap RHS operands to match LHS. - Op1CC = FCmpInst::getSwappedPredicate(Op1CC); - std::swap(Op1LHS, Op1RHS); - } - if (Op0LHS == Op1LHS && Op0RHS == Op1RHS) { - // Simplify (fcmp cc0 x, y) | (fcmp cc1 x, y). - if (Op0CC == Op1CC) - return Builder->CreateFCmp((FCmpInst::Predicate)Op0CC, Op0LHS, Op0RHS); - if (Op0CC == FCmpInst::FCMP_TRUE || Op1CC == FCmpInst::FCMP_TRUE) - return ConstantInt::get(CmpInst::makeCmpResultType(LHS->getType()), 1); - if (Op0CC == FCmpInst::FCMP_FALSE) - return RHS; - if (Op1CC == FCmpInst::FCMP_FALSE) - return LHS; - bool Op0Ordered; - bool Op1Ordered; - unsigned Op0Pred = getFCmpCode(Op0CC, Op0Ordered); - unsigned Op1Pred = getFCmpCode(Op1CC, Op1Ordered); - if (Op0Ordered == Op1Ordered) { - // If both are ordered or unordered, return a new fcmp with - // or'ed predicates. 
- return getFCmpValue(Op0Ordered, Op0Pred|Op1Pred, Op0LHS, Op0RHS, Builder); - } - } return nullptr; } @@ -2062,14 +2124,14 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyOrInst(Op0, Op1, DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // (A&B)|(A&C) -> A&(B|C) etc if (Value *V = SimplifyUsingDistributiveLaws(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // See if we can simplify any instructions used by the instruction whose sole // purpose is to compute bits we don't care about. @@ -2077,7 +2139,7 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { return &I; if (Value *V = SimplifyBSwap(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (ConstantInt *RHS = dyn_cast<ConstantInt>(Op1)) { ConstantInt *C1 = nullptr; Value *X = nullptr; @@ -2111,23 +2173,13 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { return NV; } + // Given an OR instruction, check to see if this is a bswap. + if (Instruction *BSwap = MatchBSwap(I)) + return BSwap; + Value *A = nullptr, *B = nullptr; ConstantInt *C1 = nullptr, *C2 = nullptr; - // (A | B) | C and A | (B | C) -> bswap if possible. - bool OrOfOrs = match(Op0, m_Or(m_Value(), m_Value())) || - match(Op1, m_Or(m_Value(), m_Value())); - // (A >> B) | (C << D) and (A << B) | (B >> C) -> bswap if possible. - bool OrOfShifts = match(Op0, m_LogicalShift(m_Value(), m_Value())) && - match(Op1, m_LogicalShift(m_Value(), m_Value())); - // (A & B) | (C & D) -> bswap if possible. - bool OrOfAnds = match(Op0, m_And(m_Value(), m_Value())) && - match(Op1, m_And(m_Value(), m_Value())); - - if (OrOfOrs || OrOfShifts || OrOfAnds) - if (Instruction *BSwap = MatchBSwapOrBitReverse(I)) - return BSwap; - // (X^C)|Y -> (X|Y)^C iff Y&C == 0 if (Op0->hasOneUse() && match(Op0, m_Xor(m_Value(A), m_ConstantInt(C1))) && @@ -2207,18 +2259,27 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { } } - // (A & (C0?-1:0)) | (B & ~(C0?-1:0)) -> C0 ? A : B, and commuted variants. - // Don't do this for vector select idioms, the code generator doesn't handle - // them well yet. - if (!I.getType()->isVectorTy()) { - if (Instruction *Match = MatchSelectFromAndOr(A, B, C, D)) - return Match; - if (Instruction *Match = MatchSelectFromAndOr(B, A, D, C)) - return Match; - if (Instruction *Match = MatchSelectFromAndOr(C, B, A, D)) - return Match; - if (Instruction *Match = MatchSelectFromAndOr(D, A, B, C)) - return Match; + // Don't try to form a select if it's unlikely that we'll get rid of at + // least one of the operands. A select is generally more expensive than the + // 'or' that it is replacing. + if (Op0->hasOneUse() || Op1->hasOneUse()) { + // (Cond & C) | (~Cond & D) -> Cond ? C : D, and commuted variants. 
+ if (Value *V = matchSelectFromAndOr(A, C, B, D, *Builder)) + return replaceInstUsesWith(I, V); + if (Value *V = matchSelectFromAndOr(A, C, D, B, *Builder)) + return replaceInstUsesWith(I, V); + if (Value *V = matchSelectFromAndOr(C, A, B, D, *Builder)) + return replaceInstUsesWith(I, V); + if (Value *V = matchSelectFromAndOr(C, A, D, B, *Builder)) + return replaceInstUsesWith(I, V); + if (Value *V = matchSelectFromAndOr(B, D, A, C, *Builder)) + return replaceInstUsesWith(I, V); + if (Value *V = matchSelectFromAndOr(B, D, C, A, *Builder)) + return replaceInstUsesWith(I, V); + if (Value *V = matchSelectFromAndOr(D, B, A, C, *Builder)) + return replaceInstUsesWith(I, V); + if (Value *V = matchSelectFromAndOr(D, B, C, A, *Builder)) + return replaceInstUsesWith(I, V); } // ((A&~B)|(~A&B)) -> A^B @@ -2342,7 +2403,7 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { ICmpInst *RHS = dyn_cast<ICmpInst>(Op1); if (LHS && RHS) if (Value *Res = FoldOrOfICmps(LHS, RHS, &I)) - return ReplaceInstUsesWith(I, Res); + return replaceInstUsesWith(I, Res); // TODO: Make this recursive; it's a little tricky because an arbitrary // number of 'or' instructions might have to be created. @@ -2350,18 +2411,18 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { if (LHS && match(Op1, m_OneUse(m_Or(m_Value(X), m_Value(Y))))) { if (auto *Cmp = dyn_cast<ICmpInst>(X)) if (Value *Res = FoldOrOfICmps(LHS, Cmp, &I)) - return ReplaceInstUsesWith(I, Builder->CreateOr(Res, Y)); + return replaceInstUsesWith(I, Builder->CreateOr(Res, Y)); if (auto *Cmp = dyn_cast<ICmpInst>(Y)) if (Value *Res = FoldOrOfICmps(LHS, Cmp, &I)) - return ReplaceInstUsesWith(I, Builder->CreateOr(Res, X)); + return replaceInstUsesWith(I, Builder->CreateOr(Res, X)); } if (RHS && match(Op0, m_OneUse(m_Or(m_Value(X), m_Value(Y))))) { if (auto *Cmp = dyn_cast<ICmpInst>(X)) if (Value *Res = FoldOrOfICmps(Cmp, RHS, &I)) - return ReplaceInstUsesWith(I, Builder->CreateOr(Res, Y)); + return replaceInstUsesWith(I, Builder->CreateOr(Res, Y)); if (auto *Cmp = dyn_cast<ICmpInst>(Y)) if (Value *Res = FoldOrOfICmps(Cmp, RHS, &I)) - return ReplaceInstUsesWith(I, Builder->CreateOr(Res, X)); + return replaceInstUsesWith(I, Builder->CreateOr(Res, X)); } } @@ -2369,48 +2430,17 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { if (FCmpInst *LHS = dyn_cast<FCmpInst>(I.getOperand(0))) if (FCmpInst *RHS = dyn_cast<FCmpInst>(I.getOperand(1))) if (Value *Res = FoldOrOfFCmps(LHS, RHS)) - return ReplaceInstUsesWith(I, Res); - - // fold (or (cast A), (cast B)) -> (cast (or A, B)) - if (CastInst *Op0C = dyn_cast<CastInst>(Op0)) { - CastInst *Op1C = dyn_cast<CastInst>(Op1); - if (Op1C && Op0C->getOpcode() == Op1C->getOpcode()) {// same cast kind ? - Type *SrcTy = Op0C->getOperand(0)->getType(); - if (SrcTy == Op1C->getOperand(0)->getType() && - SrcTy->isIntOrIntVectorTy()) { - Value *Op0COp = Op0C->getOperand(0), *Op1COp = Op1C->getOperand(0); - - if ((!isa<ICmpInst>(Op0COp) || !isa<ICmpInst>(Op1COp)) && - // Only do this if the casts both really cause code to be - // generated. - ShouldOptimizeCast(Op0C->getOpcode(), Op0COp, I.getType()) && - ShouldOptimizeCast(Op1C->getOpcode(), Op1COp, I.getType())) { - Value *NewOp = Builder->CreateOr(Op0COp, Op1COp, I.getName()); - return CastInst::Create(Op0C->getOpcode(), NewOp, I.getType()); - } + return replaceInstUsesWith(I, Res); - // If this is or(cast(icmp), cast(icmp)), try to fold this even if the - // cast is otherwise not optimizable. This happens for vector sexts. 
- if (ICmpInst *RHS = dyn_cast<ICmpInst>(Op1COp)) - if (ICmpInst *LHS = dyn_cast<ICmpInst>(Op0COp)) - if (Value *Res = FoldOrOfICmps(LHS, RHS, &I)) - return CastInst::Create(Op0C->getOpcode(), Res, I.getType()); - - // If this is or(cast(fcmp), cast(fcmp)), try to fold this even if the - // cast is otherwise not optimizable. This happens for vector sexts. - if (FCmpInst *RHS = dyn_cast<FCmpInst>(Op1COp)) - if (FCmpInst *LHS = dyn_cast<FCmpInst>(Op0COp)) - if (Value *Res = FoldOrOfFCmps(LHS, RHS)) - return CastInst::Create(Op0C->getOpcode(), Res, I.getType()); - } - } - } + if (Instruction *CastedOr = foldCastedBitwiseLogic(I)) + return CastedOr; - // or(sext(A), B) -> A ? -1 : B where A is an i1 - // or(A, sext(B)) -> B ? -1 : A where B is an i1 - if (match(Op0, m_SExt(m_Value(A))) && A->getType()->isIntegerTy(1)) + // or(sext(A), B) / or(B, sext(A)) --> A ? -1 : B, where A is i1 or <N x i1>. + if (match(Op0, m_OneUse(m_SExt(m_Value(A)))) && + A->getType()->getScalarType()->isIntegerTy(1)) return SelectInst::Create(A, ConstantInt::getSigned(I.getType(), -1), Op1); - if (match(Op1, m_SExt(m_Value(A))) && A->getType()->isIntegerTy(1)) + if (match(Op1, m_OneUse(m_SExt(m_Value(A)))) && + A->getType()->getScalarType()->isIntegerTy(1)) return SelectInst::Create(A, ConstantInt::getSigned(I.getType(), -1), Op0); // Note: If we've gotten to the point of visiting the outer OR, then the @@ -2447,14 +2477,14 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyXorInst(Op0, Op1, DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // (A&B)^(A&C) -> A&(B^C) etc if (Value *V = SimplifyUsingDistributiveLaws(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // See if we can simplify any instructions used by the instruction whose sole // purpose is to compute bits we don't care about. @@ -2462,7 +2492,7 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { return &I; if (Value *V = SimplifyBSwap(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // Is this a ~ operation? if (Value *NotOp = dyn_castNotVal(&I)) { @@ -2731,29 +2761,14 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { Value *Op0 = LHS->getOperand(0), *Op1 = LHS->getOperand(1); unsigned Code = getICmpCode(LHS) ^ getICmpCode(RHS); bool isSigned = LHS->isSigned() || RHS->isSigned(); - return ReplaceInstUsesWith(I, + return replaceInstUsesWith(I, getNewICmpValue(isSigned, Code, Op0, Op1, Builder)); } } - // fold (xor (cast A), (cast B)) -> (cast (xor A, B)) - if (CastInst *Op0C = dyn_cast<CastInst>(Op0)) { - if (CastInst *Op1C = dyn_cast<CastInst>(Op1)) - if (Op0C->getOpcode() == Op1C->getOpcode()) { // same cast kind? - Type *SrcTy = Op0C->getOperand(0)->getType(); - if (SrcTy == Op1C->getOperand(0)->getType() && SrcTy->isIntegerTy() && - // Only do this if the casts both really cause code to be generated. - ShouldOptimizeCast(Op0C->getOpcode(), Op0C->getOperand(0), - I.getType()) && - ShouldOptimizeCast(Op1C->getOpcode(), Op1C->getOperand(0), - I.getType())) { - Value *NewOp = Builder->CreateXor(Op0C->getOperand(0), - Op1C->getOperand(0), I.getName()); - return CastInst::Create(Op0C->getOpcode(), NewOp, I.getType()); - } - } - } + if (Instruction *CastedXor = foldCastedBitwiseLogic(I)) + return CastedXor; return Changed ? 
&I : nullptr; } diff --git a/lib/Transforms/InstCombine/InstCombineCalls.cpp b/lib/Transforms/InstCombine/InstCombineCalls.cpp index 090245d1b22c6..8acff91345d61 100644 --- a/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -14,6 +14,7 @@ #include "InstCombineInternal.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/Loads.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/Dominators.h" @@ -29,8 +30,8 @@ using namespace PatternMatch; STATISTIC(NumSimplified, "Number of library calls simplified"); -/// getPromotedType - Return the specified type promoted as it would be to pass -/// though a va_arg area. +/// Return the specified type promoted as it would be to pass though a va_arg +/// area. static Type *getPromotedType(Type *Ty) { if (IntegerType* ITy = dyn_cast<IntegerType>(Ty)) { if (ITy->getBitWidth() < 32) @@ -39,8 +40,8 @@ static Type *getPromotedType(Type *Ty) { return Ty; } -/// reduceToSingleValueType - Given an aggregate type which ultimately holds a -/// single scalar element, like {{{type}}} or [1 x type], return type. +/// Given an aggregate type which ultimately holds a single scalar element, +/// like {{{type}}} or [1 x type], return type. static Type *reduceToSingleValueType(Type *T) { while (!T->isSingleValueType()) { if (StructType *STy = dyn_cast<StructType>(T)) { @@ -60,6 +61,23 @@ static Type *reduceToSingleValueType(Type *T) { return T; } +/// Return a constant boolean vector that has true elements in all positions +/// where the input constant data vector has an element with the sign bit set. +static Constant *getNegativeIsTrueBoolVec(ConstantDataVector *V) { + SmallVector<Constant *, 32> BoolVec; + IntegerType *BoolTy = Type::getInt1Ty(V->getContext()); + for (unsigned I = 0, E = V->getNumElements(); I != E; ++I) { + Constant *Elt = V->getElementAsConstant(I); + assert((isa<ConstantInt>(Elt) || isa<ConstantFP>(Elt)) && + "Unexpected constant data vector element type"); + bool Sign = V->getElementType()->isIntegerTy() + ? cast<ConstantInt>(Elt)->isNegative() + : cast<ConstantFP>(Elt)->isNegative(); + BoolVec.push_back(ConstantInt::get(BoolTy, Sign)); + } + return ConstantVector::get(BoolVec); +} + Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) { unsigned DstAlign = getKnownAlignment(MI->getArgOperand(0), DL, MI, AC, DT); unsigned SrcAlign = getKnownAlignment(MI->getArgOperand(1), DL, MI, AC, DT); @@ -197,7 +215,7 @@ Instruction *InstCombiner::SimplifyMemSet(MemSetInst *MI) { return nullptr; } -static Value *SimplifyX86immshift(const IntrinsicInst &II, +static Value *simplifyX86immShift(const IntrinsicInst &II, InstCombiner::BuilderTy &Builder) { bool LogicalShift = false; bool ShiftLeft = false; @@ -307,83 +325,216 @@ static Value *SimplifyX86immshift(const IntrinsicInst &II, return Builder.CreateAShr(Vec, ShiftVec); } -static Value *SimplifyX86extend(const IntrinsicInst &II, - InstCombiner::BuilderTy &Builder, - bool SignExtend) { - VectorType *SrcTy = cast<VectorType>(II.getArgOperand(0)->getType()); - VectorType *DstTy = cast<VectorType>(II.getType()); - unsigned NumDstElts = DstTy->getNumElements(); - - // Extract a subvector of the first NumDstElts lanes and sign/zero extend. 
- SmallVector<int, 8> ShuffleMask; - for (int i = 0; i != (int)NumDstElts; ++i) - ShuffleMask.push_back(i); - - Value *SV = Builder.CreateShuffleVector(II.getArgOperand(0), - UndefValue::get(SrcTy), ShuffleMask); - return SignExtend ? Builder.CreateSExt(SV, DstTy) - : Builder.CreateZExt(SV, DstTy); -} - -static Value *SimplifyX86insertps(const IntrinsicInst &II, +// Attempt to simplify AVX2 per-element shift intrinsics to a generic IR shift. +// Unlike the generic IR shifts, the intrinsics have defined behaviour for out +// of range shift amounts (logical - set to zero, arithmetic - splat sign bit). +static Value *simplifyX86varShift(const IntrinsicInst &II, InstCombiner::BuilderTy &Builder) { - if (auto *CInt = dyn_cast<ConstantInt>(II.getArgOperand(2))) { - VectorType *VecTy = cast<VectorType>(II.getType()); - assert(VecTy->getNumElements() == 4 && "insertps with wrong vector type"); - - // The immediate permute control byte looks like this: - // [3:0] - zero mask for each 32-bit lane - // [5:4] - select one 32-bit destination lane - // [7:6] - select one 32-bit source lane - - uint8_t Imm = CInt->getZExtValue(); - uint8_t ZMask = Imm & 0xf; - uint8_t DestLane = (Imm >> 4) & 0x3; - uint8_t SourceLane = (Imm >> 6) & 0x3; - - ConstantAggregateZero *ZeroVector = ConstantAggregateZero::get(VecTy); - - // If all zero mask bits are set, this was just a weird way to - // generate a zero vector. - if (ZMask == 0xf) - return ZeroVector; - - // Initialize by passing all of the first source bits through. - int ShuffleMask[4] = { 0, 1, 2, 3 }; - - // We may replace the second operand with the zero vector. - Value *V1 = II.getArgOperand(1); - - if (ZMask) { - // If the zero mask is being used with a single input or the zero mask - // overrides the destination lane, this is a shuffle with the zero vector. - if ((II.getArgOperand(0) == II.getArgOperand(1)) || - (ZMask & (1 << DestLane))) { - V1 = ZeroVector; - // We may still move 32-bits of the first source vector from one lane - // to another. - ShuffleMask[DestLane] = SourceLane; - // The zero mask may override the previous insert operation. - for (unsigned i = 0; i < 4; ++i) - if ((ZMask >> i) & 0x1) - ShuffleMask[i] = i + 4; + bool LogicalShift = false; + bool ShiftLeft = false; + + switch (II.getIntrinsicID()) { + default: + return nullptr; + case Intrinsic::x86_avx2_psrav_d: + case Intrinsic::x86_avx2_psrav_d_256: + LogicalShift = false; + ShiftLeft = false; + break; + case Intrinsic::x86_avx2_psrlv_d: + case Intrinsic::x86_avx2_psrlv_d_256: + case Intrinsic::x86_avx2_psrlv_q: + case Intrinsic::x86_avx2_psrlv_q_256: + LogicalShift = true; + ShiftLeft = false; + break; + case Intrinsic::x86_avx2_psllv_d: + case Intrinsic::x86_avx2_psllv_d_256: + case Intrinsic::x86_avx2_psllv_q: + case Intrinsic::x86_avx2_psllv_q_256: + LogicalShift = true; + ShiftLeft = true; + break; + } + assert((LogicalShift || !ShiftLeft) && "Only logical shifts can shift left"); + + // Simplify if all shift amounts are constant/undef. + auto *CShift = dyn_cast<Constant>(II.getArgOperand(1)); + if (!CShift) + return nullptr; + + auto Vec = II.getArgOperand(0); + auto VT = cast<VectorType>(II.getType()); + auto SVT = VT->getVectorElementType(); + int NumElts = VT->getNumElements(); + int BitWidth = SVT->getIntegerBitWidth(); + + // Collect each element's shift amount. + // We also collect special cases: UNDEF = -1, OUT-OF-RANGE = BitWidth. 
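  // A minimal scalar sketch of the semantics this relies on (illustrative
  // only, not part of the patch; Val/Amt are placeholder names):
  //   logical:    Amt >= BitWidth ? 0 : (Val >> Amt)        // or << for shl
  //   arithmetic: Val >> std::min(Amt, BitWidth - 1)        // sign-bit splat
  // e.g. psrav.d with constant amounts <3, 40, 2, undef> collects {3, 31, 2, -1}
  // and becomes a generic 'ashr' by <3, 31, 2, undef>, while psrlv.d with
  // <7, 32, undef, 100> sets AnyOutOfRange and the generic fold is skipped.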
+ bool AnyOutOfRange = false; + SmallVector<int, 8> ShiftAmts; + for (int I = 0; I < NumElts; ++I) { + auto *CElt = CShift->getAggregateElement(I); + if (CElt && isa<UndefValue>(CElt)) { + ShiftAmts.push_back(-1); + continue; + } + + auto *COp = dyn_cast_or_null<ConstantInt>(CElt); + if (!COp) + return nullptr; + + // Handle out of range shifts. + // If LogicalShift - set to BitWidth (special case). + // If ArithmeticShift - set to (BitWidth - 1) (sign splat). + APInt ShiftVal = COp->getValue(); + if (ShiftVal.uge(BitWidth)) { + AnyOutOfRange = LogicalShift; + ShiftAmts.push_back(LogicalShift ? BitWidth : BitWidth - 1); + continue; + } + + ShiftAmts.push_back((int)ShiftVal.getZExtValue()); + } + + // If all elements out of range or UNDEF, return vector of zeros/undefs. + // ArithmeticShift should only hit this if they are all UNDEF. + auto OutOfRange = [&](int Idx) { return (Idx < 0) || (BitWidth <= Idx); }; + if (llvm::all_of(ShiftAmts, OutOfRange)) { + SmallVector<Constant *, 8> ConstantVec; + for (int Idx : ShiftAmts) { + if (Idx < 0) { + ConstantVec.push_back(UndefValue::get(SVT)); } else { - // TODO: Model this case as 2 shuffles or a 'logical and' plus shuffle? - return nullptr; + assert(LogicalShift && "Logical shift expected"); + ConstantVec.push_back(ConstantInt::getNullValue(SVT)); } - } else { - // Replace the selected destination lane with the selected source lane. - ShuffleMask[DestLane] = SourceLane + 4; } + return ConstantVector::get(ConstantVec); + } - return Builder.CreateShuffleVector(II.getArgOperand(0), V1, ShuffleMask); + // We can't handle only some out of range values with generic logical shifts. + if (AnyOutOfRange) + return nullptr; + + // Build the shift amount constant vector. + SmallVector<Constant *, 8> ShiftVecAmts; + for (int Idx : ShiftAmts) { + if (Idx < 0) + ShiftVecAmts.push_back(UndefValue::get(SVT)); + else + ShiftVecAmts.push_back(ConstantInt::get(SVT, Idx)); } - return nullptr; + auto ShiftVec = ConstantVector::get(ShiftVecAmts); + + if (ShiftLeft) + return Builder.CreateShl(Vec, ShiftVec); + + if (LogicalShift) + return Builder.CreateLShr(Vec, ShiftVec); + + return Builder.CreateAShr(Vec, ShiftVec); +} + +static Value *simplifyX86movmsk(const IntrinsicInst &II, + InstCombiner::BuilderTy &Builder) { + Value *Arg = II.getArgOperand(0); + Type *ResTy = II.getType(); + Type *ArgTy = Arg->getType(); + + // movmsk(undef) -> zero as we must ensure the upper bits are zero. + if (isa<UndefValue>(Arg)) + return Constant::getNullValue(ResTy); + + // We can't easily peek through x86_mmx types. + if (!ArgTy->isVectorTy()) + return nullptr; + + auto *C = dyn_cast<Constant>(Arg); + if (!C) + return nullptr; + + // Extract signbits of the vector input and pack into integer result. 
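  // Worked example (illustrative, not part of the patch): movmsk of the
  // constant <4 x float> <-1.0, 2.0, -0.0, undef> sets bit 0 (negative),
  // leaves bit 1 clear, sets bit 2 (-0.0 has its sign bit set) and skips the
  // undef lane, so the call folds to the i32 constant 0b0101 (i.e. 5).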
+ APInt Result(ResTy->getPrimitiveSizeInBits(), 0); + for (unsigned I = 0, E = ArgTy->getVectorNumElements(); I != E; ++I) { + auto *COp = C->getAggregateElement(I); + if (!COp) + return nullptr; + if (isa<UndefValue>(COp)) + continue; + + auto *CInt = dyn_cast<ConstantInt>(COp); + auto *CFp = dyn_cast<ConstantFP>(COp); + if (!CInt && !CFp) + return nullptr; + + if ((CInt && CInt->isNegative()) || (CFp && CFp->isNegative())) + Result.setBit(I); + } + + return Constant::getIntegerValue(ResTy, Result); +} + +static Value *simplifyX86insertps(const IntrinsicInst &II, + InstCombiner::BuilderTy &Builder) { + auto *CInt = dyn_cast<ConstantInt>(II.getArgOperand(2)); + if (!CInt) + return nullptr; + + VectorType *VecTy = cast<VectorType>(II.getType()); + assert(VecTy->getNumElements() == 4 && "insertps with wrong vector type"); + + // The immediate permute control byte looks like this: + // [3:0] - zero mask for each 32-bit lane + // [5:4] - select one 32-bit destination lane + // [7:6] - select one 32-bit source lane + + uint8_t Imm = CInt->getZExtValue(); + uint8_t ZMask = Imm & 0xf; + uint8_t DestLane = (Imm >> 4) & 0x3; + uint8_t SourceLane = (Imm >> 6) & 0x3; + + ConstantAggregateZero *ZeroVector = ConstantAggregateZero::get(VecTy); + + // If all zero mask bits are set, this was just a weird way to + // generate a zero vector. + if (ZMask == 0xf) + return ZeroVector; + + // Initialize by passing all of the first source bits through. + uint32_t ShuffleMask[4] = { 0, 1, 2, 3 }; + + // We may replace the second operand with the zero vector. + Value *V1 = II.getArgOperand(1); + + if (ZMask) { + // If the zero mask is being used with a single input or the zero mask + // overrides the destination lane, this is a shuffle with the zero vector. + if ((II.getArgOperand(0) == II.getArgOperand(1)) || + (ZMask & (1 << DestLane))) { + V1 = ZeroVector; + // We may still move 32-bits of the first source vector from one lane + // to another. + ShuffleMask[DestLane] = SourceLane; + // The zero mask may override the previous insert operation. + for (unsigned i = 0; i < 4; ++i) + if ((ZMask >> i) & 0x1) + ShuffleMask[i] = i + 4; + } else { + // TODO: Model this case as 2 shuffles or a 'logical and' plus shuffle? + return nullptr; + } + } else { + // Replace the selected destination lane with the selected source lane. + ShuffleMask[DestLane] = SourceLane + 4; + } + + return Builder.CreateShuffleVector(II.getArgOperand(0), V1, ShuffleMask); } /// Attempt to simplify SSE4A EXTRQ/EXTRQI instructions using constant folding /// or conversion to a shuffle vector. -static Value *SimplifyX86extrq(IntrinsicInst &II, Value *Op0, +static Value *simplifyX86extrq(IntrinsicInst &II, Value *Op0, ConstantInt *CILength, ConstantInt *CIIndex, InstCombiner::BuilderTy &Builder) { auto LowConstantHighUndef = [&](uint64_t Val) { @@ -476,7 +627,7 @@ static Value *SimplifyX86extrq(IntrinsicInst &II, Value *Op0, /// Attempt to simplify SSE4A INSERTQ/INSERTQI instructions using constant /// folding or conversion to a shuffle vector. -static Value *SimplifyX86insertq(IntrinsicInst &II, Value *Op0, Value *Op1, +static Value *simplifyX86insertq(IntrinsicInst &II, Value *Op0, Value *Op1, APInt APLength, APInt APIndex, InstCombiner::BuilderTy &Builder) { @@ -571,74 +722,211 @@ static Value *SimplifyX86insertq(IntrinsicInst &II, Value *Op0, Value *Op1, return nullptr; } -/// The shuffle mask for a perm2*128 selects any two halves of two 256-bit -/// source vectors, unless a zero bit is set. 
If a zero bit is set, -/// then ignore that half of the mask and clear that half of the vector. -static Value *SimplifyX86vperm2(const IntrinsicInst &II, +/// Attempt to convert pshufb* to shufflevector if the mask is constant. +static Value *simplifyX86pshufb(const IntrinsicInst &II, InstCombiner::BuilderTy &Builder) { - if (auto *CInt = dyn_cast<ConstantInt>(II.getArgOperand(2))) { - VectorType *VecTy = cast<VectorType>(II.getType()); - ConstantAggregateZero *ZeroVector = ConstantAggregateZero::get(VecTy); + Constant *V = dyn_cast<Constant>(II.getArgOperand(1)); + if (!V) + return nullptr; + + auto *VecTy = cast<VectorType>(II.getType()); + auto *MaskEltTy = Type::getInt32Ty(II.getContext()); + unsigned NumElts = VecTy->getNumElements(); + assert((NumElts == 16 || NumElts == 32) && + "Unexpected number of elements in shuffle mask!"); + + // Construct a shuffle mask from constant integers or UNDEFs. + Constant *Indexes[32] = {NULL}; + + // Each byte in the shuffle control mask forms an index to permute the + // corresponding byte in the destination operand. + for (unsigned I = 0; I < NumElts; ++I) { + Constant *COp = V->getAggregateElement(I); + if (!COp || (!isa<UndefValue>(COp) && !isa<ConstantInt>(COp))) + return nullptr; + + if (isa<UndefValue>(COp)) { + Indexes[I] = UndefValue::get(MaskEltTy); + continue; + } - // The immediate permute control byte looks like this: - // [1:0] - select 128 bits from sources for low half of destination - // [2] - ignore - // [3] - zero low half of destination - // [5:4] - select 128 bits from sources for high half of destination - // [6] - ignore - // [7] - zero high half of destination + int8_t Index = cast<ConstantInt>(COp)->getValue().getZExtValue(); - uint8_t Imm = CInt->getZExtValue(); + // If the most significant bit (bit[7]) of each byte of the shuffle + // control mask is set, then zero is written in the result byte. + // The zero vector is in the right-hand side of the resulting + // shufflevector. - bool LowHalfZero = Imm & 0x08; - bool HighHalfZero = Imm & 0x80; + // The value of each index for the high 128-bit lane is the least + // significant 4 bits of the respective shuffle control byte. + Index = ((Index < 0) ? NumElts : Index & 0x0F) + (I & 0xF0); + Indexes[I] = ConstantInt::get(MaskEltTy, Index); + } - // If both zero mask bits are set, this was just a weird way to - // generate a zero vector. - if (LowHalfZero && HighHalfZero) - return ZeroVector; + auto ShuffleMask = ConstantVector::get(makeArrayRef(Indexes, NumElts)); + auto V1 = II.getArgOperand(0); + auto V2 = Constant::getNullValue(VecTy); + return Builder.CreateShuffleVector(V1, V2, ShuffleMask); +} - // If 0 or 1 zero mask bits are set, this is a simple shuffle. - unsigned NumElts = VecTy->getNumElements(); - unsigned HalfSize = NumElts / 2; - SmallVector<int, 8> ShuffleMask(NumElts); +/// Attempt to convert vpermilvar* to shufflevector if the mask is constant. +static Value *simplifyX86vpermilvar(const IntrinsicInst &II, + InstCombiner::BuilderTy &Builder) { + Constant *V = dyn_cast<Constant>(II.getArgOperand(1)); + if (!V) + return nullptr; + + auto *MaskEltTy = Type::getInt32Ty(II.getContext()); + unsigned NumElts = cast<VectorType>(V->getType())->getNumElements(); + assert(NumElts == 8 || NumElts == 4 || NumElts == 2); - // The high bit of the selection field chooses the 1st or 2nd operand. - bool LowInputSelect = Imm & 0x02; - bool HighInputSelect = Imm & 0x20; + // Construct a shuffle mask from constant integers or UNDEFs. 
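  // For example (illustrative): vpermilvar.ps.256 with the constant mask
  // <1,0,3,2, 1,0,3,2> becomes a shufflevector mask of <1,0,3,2, 5,4,7,6>
  // (upper-half indices are rebased by NumElts/2), and vpermilvar.pd with
  // mask <2, 0> becomes <1, 0>, since the pd form only reads bit 1.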
+ Constant *Indexes[8] = {NULL}; - // The low bit of the selection field chooses the low or high half - // of the selected operand. - bool LowHalfSelect = Imm & 0x01; - bool HighHalfSelect = Imm & 0x10; + // The intrinsics only read one or two bits, clear the rest. + for (unsigned I = 0; I < NumElts; ++I) { + Constant *COp = V->getAggregateElement(I); + if (!COp || (!isa<UndefValue>(COp) && !isa<ConstantInt>(COp))) + return nullptr; - // Determine which operand(s) are actually in use for this instruction. - Value *V0 = LowInputSelect ? II.getArgOperand(1) : II.getArgOperand(0); - Value *V1 = HighInputSelect ? II.getArgOperand(1) : II.getArgOperand(0); + if (isa<UndefValue>(COp)) { + Indexes[I] = UndefValue::get(MaskEltTy); + continue; + } - // If needed, replace operands based on zero mask. - V0 = LowHalfZero ? ZeroVector : V0; - V1 = HighHalfZero ? ZeroVector : V1; + APInt Index = cast<ConstantInt>(COp)->getValue(); + Index = Index.zextOrTrunc(32).getLoBits(2); - // Permute low half of result. - unsigned StartIndex = LowHalfSelect ? HalfSize : 0; - for (unsigned i = 0; i < HalfSize; ++i) - ShuffleMask[i] = StartIndex + i; + // The PD variants uses bit 1 to select per-lane element index, so + // shift down to convert to generic shuffle mask index. + if (II.getIntrinsicID() == Intrinsic::x86_avx_vpermilvar_pd || + II.getIntrinsicID() == Intrinsic::x86_avx_vpermilvar_pd_256) + Index = Index.lshr(1); - // Permute high half of result. - StartIndex = HighHalfSelect ? HalfSize : 0; - StartIndex += NumElts; - for (unsigned i = 0; i < HalfSize; ++i) - ShuffleMask[i + HalfSize] = StartIndex + i; + // The _256 variants are a bit trickier since the mask bits always index + // into the corresponding 128 half. In order to convert to a generic + // shuffle, we have to make that explicit. + if ((II.getIntrinsicID() == Intrinsic::x86_avx_vpermilvar_ps_256 || + II.getIntrinsicID() == Intrinsic::x86_avx_vpermilvar_pd_256) && + ((NumElts / 2) <= I)) { + Index += APInt(32, NumElts / 2); + } - return Builder.CreateShuffleVector(V0, V1, ShuffleMask); + Indexes[I] = ConstantInt::get(MaskEltTy, Index); } - return nullptr; + + auto ShuffleMask = ConstantVector::get(makeArrayRef(Indexes, NumElts)); + auto V1 = II.getArgOperand(0); + auto V2 = UndefValue::get(V1->getType()); + return Builder.CreateShuffleVector(V1, V2, ShuffleMask); +} + +/// Attempt to convert vpermd/vpermps to shufflevector if the mask is constant. +static Value *simplifyX86vpermv(const IntrinsicInst &II, + InstCombiner::BuilderTy &Builder) { + auto *V = dyn_cast<Constant>(II.getArgOperand(1)); + if (!V) + return nullptr; + + auto *VecTy = cast<VectorType>(II.getType()); + auto *MaskEltTy = Type::getInt32Ty(II.getContext()); + unsigned Size = VecTy->getNumElements(); + assert(Size == 8 && "Unexpected shuffle mask size"); + + // Construct a shuffle mask from constant integers or UNDEFs. 
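  // e.g. (illustrative): vpermd with the index vector <7,6,5,4,3,2,1,0>
  // becomes a shufflevector that reverses the eight 32-bit lanes; only the
  // low 3 bits of each index are kept, so an index value of 9 selects lane 1.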
+ Constant *Indexes[8] = {NULL}; + + for (unsigned I = 0; I < Size; ++I) { + Constant *COp = V->getAggregateElement(I); + if (!COp || (!isa<UndefValue>(COp) && !isa<ConstantInt>(COp))) + return nullptr; + + if (isa<UndefValue>(COp)) { + Indexes[I] = UndefValue::get(MaskEltTy); + continue; + } + + APInt Index = cast<ConstantInt>(COp)->getValue(); + Index = Index.zextOrTrunc(32).getLoBits(3); + Indexes[I] = ConstantInt::get(MaskEltTy, Index); + } + + auto ShuffleMask = ConstantVector::get(makeArrayRef(Indexes, Size)); + auto V1 = II.getArgOperand(0); + auto V2 = UndefValue::get(VecTy); + return Builder.CreateShuffleVector(V1, V2, ShuffleMask); +} + +/// The shuffle mask for a perm2*128 selects any two halves of two 256-bit +/// source vectors, unless a zero bit is set. If a zero bit is set, +/// then ignore that half of the mask and clear that half of the vector. +static Value *simplifyX86vperm2(const IntrinsicInst &II, + InstCombiner::BuilderTy &Builder) { + auto *CInt = dyn_cast<ConstantInt>(II.getArgOperand(2)); + if (!CInt) + return nullptr; + + VectorType *VecTy = cast<VectorType>(II.getType()); + ConstantAggregateZero *ZeroVector = ConstantAggregateZero::get(VecTy); + + // The immediate permute control byte looks like this: + // [1:0] - select 128 bits from sources for low half of destination + // [2] - ignore + // [3] - zero low half of destination + // [5:4] - select 128 bits from sources for high half of destination + // [6] - ignore + // [7] - zero high half of destination + + uint8_t Imm = CInt->getZExtValue(); + + bool LowHalfZero = Imm & 0x08; + bool HighHalfZero = Imm & 0x80; + + // If both zero mask bits are set, this was just a weird way to + // generate a zero vector. + if (LowHalfZero && HighHalfZero) + return ZeroVector; + + // If 0 or 1 zero mask bits are set, this is a simple shuffle. + unsigned NumElts = VecTy->getNumElements(); + unsigned HalfSize = NumElts / 2; + SmallVector<uint32_t, 8> ShuffleMask(NumElts); + + // The high bit of the selection field chooses the 1st or 2nd operand. + bool LowInputSelect = Imm & 0x02; + bool HighInputSelect = Imm & 0x20; + + // The low bit of the selection field chooses the low or high half + // of the selected operand. + bool LowHalfSelect = Imm & 0x01; + bool HighHalfSelect = Imm & 0x10; + + // Determine which operand(s) are actually in use for this instruction. + Value *V0 = LowInputSelect ? II.getArgOperand(1) : II.getArgOperand(0); + Value *V1 = HighInputSelect ? II.getArgOperand(1) : II.getArgOperand(0); + + // If needed, replace operands based on zero mask. + V0 = LowHalfZero ? ZeroVector : V0; + V1 = HighHalfZero ? ZeroVector : V1; + + // Permute low half of result. + unsigned StartIndex = LowHalfSelect ? HalfSize : 0; + for (unsigned i = 0; i < HalfSize; ++i) + ShuffleMask[i] = StartIndex + i; + + // Permute high half of result. + StartIndex = HighHalfSelect ? HalfSize : 0; + StartIndex += NumElts; + for (unsigned i = 0; i < HalfSize; ++i) + ShuffleMask[i + HalfSize] = StartIndex + i; + + return Builder.CreateShuffleVector(V0, V1, ShuffleMask); } /// Decode XOP integer vector comparison intrinsics. 
-static Value *SimplifyX86vpcom(const IntrinsicInst &II, - InstCombiner::BuilderTy &Builder, bool IsSigned) { +static Value *simplifyX86vpcom(const IntrinsicInst &II, + InstCombiner::BuilderTy &Builder, + bool IsSigned) { if (auto *CInt = dyn_cast<ConstantInt>(II.getArgOperand(2))) { uint64_t Imm = CInt->getZExtValue() & 0x7; VectorType *VecTy = cast<VectorType>(II.getType()); @@ -667,21 +955,296 @@ static Value *SimplifyX86vpcom(const IntrinsicInst &II, return ConstantInt::getSigned(VecTy, -1); // TRUE } - if (Value *Cmp = Builder.CreateICmp(Pred, II.getArgOperand(0), II.getArgOperand(1))) + if (Value *Cmp = Builder.CreateICmp(Pred, II.getArgOperand(0), + II.getArgOperand(1))) return Builder.CreateSExtOrTrunc(Cmp, VecTy); } return nullptr; } -/// visitCallInst - CallInst simplification. This mostly only handles folding -/// of intrinsic instructions. For normal calls, it allows visitCallSite to do -/// the heavy lifting. -/// +static Value *simplifyMinnumMaxnum(const IntrinsicInst &II) { + Value *Arg0 = II.getArgOperand(0); + Value *Arg1 = II.getArgOperand(1); + + // fmin(x, x) -> x + if (Arg0 == Arg1) + return Arg0; + + const auto *C1 = dyn_cast<ConstantFP>(Arg1); + + // fmin(x, nan) -> x + if (C1 && C1->isNaN()) + return Arg0; + + // This is the value because if undef were NaN, we would return the other + // value and cannot return a NaN unless both operands are. + // + // fmin(undef, x) -> x + if (isa<UndefValue>(Arg0)) + return Arg1; + + // fmin(x, undef) -> x + if (isa<UndefValue>(Arg1)) + return Arg0; + + Value *X = nullptr; + Value *Y = nullptr; + if (II.getIntrinsicID() == Intrinsic::minnum) { + // fmin(x, fmin(x, y)) -> fmin(x, y) + // fmin(y, fmin(x, y)) -> fmin(x, y) + if (match(Arg1, m_FMin(m_Value(X), m_Value(Y)))) { + if (Arg0 == X || Arg0 == Y) + return Arg1; + } + + // fmin(fmin(x, y), x) -> fmin(x, y) + // fmin(fmin(x, y), y) -> fmin(x, y) + if (match(Arg0, m_FMin(m_Value(X), m_Value(Y)))) { + if (Arg1 == X || Arg1 == Y) + return Arg0; + } + + // TODO: fmin(nnan x, inf) -> x + // TODO: fmin(nnan ninf x, flt_max) -> x + if (C1 && C1->isInfinity()) { + // fmin(x, -inf) -> -inf + if (C1->isNegative()) + return Arg1; + } + } else { + assert(II.getIntrinsicID() == Intrinsic::maxnum); + // fmax(x, fmax(x, y)) -> fmax(x, y) + // fmax(y, fmax(x, y)) -> fmax(x, y) + if (match(Arg1, m_FMax(m_Value(X), m_Value(Y)))) { + if (Arg0 == X || Arg0 == Y) + return Arg1; + } + + // fmax(fmax(x, y), x) -> fmax(x, y) + // fmax(fmax(x, y), y) -> fmax(x, y) + if (match(Arg0, m_FMax(m_Value(X), m_Value(Y)))) { + if (Arg1 == X || Arg1 == Y) + return Arg0; + } + + // TODO: fmax(nnan x, -inf) -> x + // TODO: fmax(nnan ninf x, -flt_max) -> x + if (C1 && C1->isInfinity()) { + // fmax(x, inf) -> inf + if (!C1->isNegative()) + return Arg1; + } + } + return nullptr; +} + +static bool maskIsAllOneOrUndef(Value *Mask) { + auto *ConstMask = dyn_cast<Constant>(Mask); + if (!ConstMask) + return false; + if (ConstMask->isAllOnesValue() || isa<UndefValue>(ConstMask)) + return true; + for (unsigned I = 0, E = ConstMask->getType()->getVectorNumElements(); I != E; + ++I) { + if (auto *MaskElt = ConstMask->getAggregateElement(I)) + if (MaskElt->isAllOnesValue() || isa<UndefValue>(MaskElt)) + continue; + return false; + } + return true; +} + +static Value *simplifyMaskedLoad(const IntrinsicInst &II, + InstCombiner::BuilderTy &Builder) { + // If the mask is all ones or undefs, this is a plain vector load of the 1st + // argument. 
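  // e.g. (schematic, shown only as a comment):
  //   %v = call <4 x i32> @llvm.masked.load(<4 x i32>* %p, i32 16,
  //                                         <4 x i1> <all true>, <4 x i32> %pt)
  //   --> load <4 x i32>, <4 x i32>* %p, align 16
  // The alignment comes from operand 1 and the passthru operand becomes dead.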
+ if (maskIsAllOneOrUndef(II.getArgOperand(2))) { + Value *LoadPtr = II.getArgOperand(0); + unsigned Alignment = cast<ConstantInt>(II.getArgOperand(1))->getZExtValue(); + return Builder.CreateAlignedLoad(LoadPtr, Alignment, "unmaskedload"); + } + + return nullptr; +} + +static Instruction *simplifyMaskedStore(IntrinsicInst &II, InstCombiner &IC) { + auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3)); + if (!ConstMask) + return nullptr; + + // If the mask is all zeros, this instruction does nothing. + if (ConstMask->isNullValue()) + return IC.eraseInstFromFunction(II); + + // If the mask is all ones, this is a plain vector store of the 1st argument. + if (ConstMask->isAllOnesValue()) { + Value *StorePtr = II.getArgOperand(1); + unsigned Alignment = cast<ConstantInt>(II.getArgOperand(2))->getZExtValue(); + return new StoreInst(II.getArgOperand(0), StorePtr, false, Alignment); + } + + return nullptr; +} + +static Instruction *simplifyMaskedGather(IntrinsicInst &II, InstCombiner &IC) { + // If the mask is all zeros, return the "passthru" argument of the gather. + auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(2)); + if (ConstMask && ConstMask->isNullValue()) + return IC.replaceInstUsesWith(II, II.getArgOperand(3)); + + return nullptr; +} + +static Instruction *simplifyMaskedScatter(IntrinsicInst &II, InstCombiner &IC) { + // If the mask is all zeros, a scatter does nothing. + auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3)); + if (ConstMask && ConstMask->isNullValue()) + return IC.eraseInstFromFunction(II); + + return nullptr; +} + +// TODO: If the x86 backend knew how to convert a bool vector mask back to an +// XMM register mask efficiently, we could transform all x86 masked intrinsics +// to LLVM masked intrinsics and remove the x86 masked intrinsic defs. +static Instruction *simplifyX86MaskedLoad(IntrinsicInst &II, InstCombiner &IC) { + Value *Ptr = II.getOperand(0); + Value *Mask = II.getOperand(1); + Constant *ZeroVec = Constant::getNullValue(II.getType()); + + // Special case a zero mask since that's not a ConstantDataVector. + // This masked load instruction creates a zero vector. + if (isa<ConstantAggregateZero>(Mask)) + return IC.replaceInstUsesWith(II, ZeroVec); + + auto *ConstMask = dyn_cast<ConstantDataVector>(Mask); + if (!ConstMask) + return nullptr; + + // The mask is constant. Convert this x86 intrinsic to the LLVM instrinsic + // to allow target-independent optimizations. + + // First, cast the x86 intrinsic scalar pointer to a vector pointer to match + // the LLVM intrinsic definition for the pointer argument. + unsigned AddrSpace = cast<PointerType>(Ptr->getType())->getAddressSpace(); + PointerType *VecPtrTy = PointerType::get(II.getType(), AddrSpace); + Value *PtrCast = IC.Builder->CreateBitCast(Ptr, VecPtrTy, "castvec"); + + // Second, convert the x86 XMM integer vector mask to a vector of bools based + // on each element's most significant bit (the sign bit). + Constant *BoolMask = getNegativeIsTrueBoolVec(ConstMask); + + // The pass-through vector for an x86 masked load is a zero vector. + CallInst *NewMaskedLoad = + IC.Builder->CreateMaskedLoad(PtrCast, 1, BoolMask, ZeroVec); + return IC.replaceInstUsesWith(II, NewMaskedLoad); +} + +// TODO: If the x86 backend knew how to convert a bool vector mask back to an +// XMM register mask efficiently, we could transform all x86 masked intrinsics +// to LLVM masked intrinsics and remove the x86 masked intrinsic defs. 
+static bool simplifyX86MaskedStore(IntrinsicInst &II, InstCombiner &IC) { + Value *Ptr = II.getOperand(0); + Value *Mask = II.getOperand(1); + Value *Vec = II.getOperand(2); + + // Special case a zero mask since that's not a ConstantDataVector: + // this masked store instruction does nothing. + if (isa<ConstantAggregateZero>(Mask)) { + IC.eraseInstFromFunction(II); + return true; + } + + // The SSE2 version is too weird (eg, unaligned but non-temporal) to do + // anything else at this level. + if (II.getIntrinsicID() == Intrinsic::x86_sse2_maskmov_dqu) + return false; + + auto *ConstMask = dyn_cast<ConstantDataVector>(Mask); + if (!ConstMask) + return false; + + // The mask is constant. Convert this x86 intrinsic to the LLVM instrinsic + // to allow target-independent optimizations. + + // First, cast the x86 intrinsic scalar pointer to a vector pointer to match + // the LLVM intrinsic definition for the pointer argument. + unsigned AddrSpace = cast<PointerType>(Ptr->getType())->getAddressSpace(); + PointerType *VecPtrTy = PointerType::get(Vec->getType(), AddrSpace); + Value *PtrCast = IC.Builder->CreateBitCast(Ptr, VecPtrTy, "castvec"); + + // Second, convert the x86 XMM integer vector mask to a vector of bools based + // on each element's most significant bit (the sign bit). + Constant *BoolMask = getNegativeIsTrueBoolVec(ConstMask); + + IC.Builder->CreateMaskedStore(Vec, PtrCast, 1, BoolMask); + + // 'Replace uses' doesn't work for stores. Erase the original masked store. + IC.eraseInstFromFunction(II); + return true; +} + +// Returns true iff the 2 intrinsics have the same operands, limiting the +// comparison to the first NumOperands. +static bool haveSameOperands(const IntrinsicInst &I, const IntrinsicInst &E, + unsigned NumOperands) { + assert(I.getNumArgOperands() >= NumOperands && "Not enough operands"); + assert(E.getNumArgOperands() >= NumOperands && "Not enough operands"); + for (unsigned i = 0; i < NumOperands; i++) + if (I.getArgOperand(i) != E.getArgOperand(i)) + return false; + return true; +} + +// Remove trivially empty start/end intrinsic ranges, i.e. a start +// immediately followed by an end (ignoring debuginfo or other +// start/end intrinsics in between). As this handles only the most trivial +// cases, tracking the nesting level is not needed: +// +// call @llvm.foo.start(i1 0) ; &I +// call @llvm.foo.start(i1 0) +// call @llvm.foo.end(i1 0) ; This one will not be skipped: it will be removed +// call @llvm.foo.end(i1 0) +static bool removeTriviallyEmptyRange(IntrinsicInst &I, unsigned StartID, + unsigned EndID, InstCombiner &IC) { + assert(I.getIntrinsicID() == StartID && + "Start intrinsic does not have expected ID"); + BasicBlock::iterator BI(I), BE(I.getParent()->end()); + for (++BI; BI != BE; ++BI) { + if (auto *E = dyn_cast<IntrinsicInst>(BI)) { + if (isa<DbgInfoIntrinsic>(E) || E->getIntrinsicID() == StartID) + continue; + if (E->getIntrinsicID() == EndID && + haveSameOperands(I, *E, E->getNumArgOperands())) { + IC.eraseInstFromFunction(*E); + IC.eraseInstFromFunction(I); + return true; + } + } + break; + } + + return false; +} + +Instruction *InstCombiner::visitVAStartInst(VAStartInst &I) { + removeTriviallyEmptyRange(I, Intrinsic::vastart, Intrinsic::vaend, *this); + return nullptr; +} + +Instruction *InstCombiner::visitVACopyInst(VACopyInst &I) { + removeTriviallyEmptyRange(I, Intrinsic::vacopy, Intrinsic::vaend, *this); + return nullptr; +} + +/// CallInst simplification. This mostly only handles folding of intrinsic +/// instructions. 
For normal calls, it allows visitCallSite to do the heavy +/// lifting. Instruction *InstCombiner::visitCallInst(CallInst &CI) { auto Args = CI.arg_operands(); if (Value *V = SimplifyCall(CI.getCalledValue(), Args.begin(), Args.end(), DL, TLI, DT, AC)) - return ReplaceInstUsesWith(CI, V); + return replaceInstUsesWith(CI, V); if (isFreeCall(&CI, TLI)) return visitFree(CI); @@ -705,7 +1268,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // memmove/cpy/set of zero bytes is a noop. if (Constant *NumBytes = dyn_cast<Constant>(MI->getLength())) { if (NumBytes->isNullValue()) - return EraseInstFromFunction(CI); + return eraseInstFromFunction(CI); if (ConstantInt *CI = dyn_cast<ConstantInt>(NumBytes)) if (CI->getZExtValue() == 1) { @@ -738,7 +1301,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) { // memmove(x,x,size) -> noop. if (MTI->getSource() == MTI->getDest()) - return EraseInstFromFunction(CI); + return eraseInstFromFunction(CI); } // If we can determine a pointer alignment that is bigger than currently @@ -754,19 +1317,30 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { if (Changed) return II; } - auto SimplifyDemandedVectorEltsLow = [this](Value *Op, unsigned Width, unsigned DemandedWidth) - { + auto SimplifyDemandedVectorEltsLow = [this](Value *Op, unsigned Width, + unsigned DemandedWidth) { APInt UndefElts(Width, 0); APInt DemandedElts = APInt::getLowBitsSet(Width, DemandedWidth); return SimplifyDemandedVectorElts(Op, DemandedElts, UndefElts); }; + auto SimplifyDemandedVectorEltsHigh = [this](Value *Op, unsigned Width, + unsigned DemandedWidth) { + APInt UndefElts(Width, 0); + APInt DemandedElts = APInt::getHighBitsSet(Width, DemandedWidth); + return SimplifyDemandedVectorElts(Op, DemandedElts, UndefElts); + }; switch (II->getIntrinsicID()) { default: break; case Intrinsic::objectsize: { uint64_t Size; - if (getObjectSize(II->getArgOperand(0), Size, DL, TLI)) - return ReplaceInstUsesWith(CI, ConstantInt::get(CI.getType(), Size)); + if (getObjectSize(II->getArgOperand(0), Size, DL, TLI)) { + APInt APSize(II->getType()->getIntegerBitWidth(), Size); + // Equality check to be sure that `Size` can fit in a value of type + // `II->getType()` + if (APSize == Size) + return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), APSize)); + } return nullptr; } case Intrinsic::bswap: { @@ -775,7 +1349,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // bswap(bswap(x)) -> x if (match(IIOperand, m_BSwap(m_Value(X)))) - return ReplaceInstUsesWith(CI, X); + return replaceInstUsesWith(CI, X); // bswap(trunc(bswap(x))) -> trunc(lshr(x, c)) if (match(IIOperand, m_Trunc(m_BSwap(m_Value(X))))) { @@ -794,18 +1368,29 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // bitreverse(bitreverse(x)) -> x if (match(IIOperand, m_Intrinsic<Intrinsic::bitreverse>(m_Value(X)))) - return ReplaceInstUsesWith(CI, X); + return replaceInstUsesWith(CI, X); break; } + case Intrinsic::masked_load: + if (Value *SimplifiedMaskedOp = simplifyMaskedLoad(*II, *Builder)) + return replaceInstUsesWith(CI, SimplifiedMaskedOp); + break; + case Intrinsic::masked_store: + return simplifyMaskedStore(*II, *this); + case Intrinsic::masked_gather: + return simplifyMaskedGather(*II, *this); + case Intrinsic::masked_scatter: + return simplifyMaskedScatter(*II, *this); + case Intrinsic::powi: if (ConstantInt *Power = dyn_cast<ConstantInt>(II->getArgOperand(1))) { // powi(x, 0) -> 1.0 if (Power->isZero()) - return 
ReplaceInstUsesWith(CI, ConstantFP::get(CI.getType(), 1.0)); + return replaceInstUsesWith(CI, ConstantFP::get(CI.getType(), 1.0)); // powi(x, 1) -> x if (Power->isOne()) - return ReplaceInstUsesWith(CI, II->getArgOperand(0)); + return replaceInstUsesWith(CI, II->getArgOperand(0)); // powi(x, -1) -> 1/x if (Power->isAllOnesValue()) return BinaryOperator::CreateFDiv(ConstantFP::get(CI.getType(), 1.0), @@ -825,7 +1410,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { unsigned TrailingZeros = KnownOne.countTrailingZeros(); APInt Mask(APInt::getLowBitsSet(BitWidth, TrailingZeros)); if ((Mask & KnownZero) == Mask) - return ReplaceInstUsesWith(CI, ConstantInt::get(IT, + return replaceInstUsesWith(CI, ConstantInt::get(IT, APInt(BitWidth, TrailingZeros))); } @@ -843,7 +1428,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { unsigned LeadingZeros = KnownOne.countLeadingZeros(); APInt Mask(APInt::getHighBitsSet(BitWidth, LeadingZeros)); if ((Mask & KnownZero) == Mask) - return ReplaceInstUsesWith(CI, ConstantInt::get(IT, + return replaceInstUsesWith(CI, ConstantInt::get(IT, APInt(BitWidth, LeadingZeros))); } @@ -882,84 +1467,14 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::maxnum: { Value *Arg0 = II->getArgOperand(0); Value *Arg1 = II->getArgOperand(1); - - // fmin(x, x) -> x - if (Arg0 == Arg1) - return ReplaceInstUsesWith(CI, Arg0); - - const ConstantFP *C0 = dyn_cast<ConstantFP>(Arg0); - const ConstantFP *C1 = dyn_cast<ConstantFP>(Arg1); - - // Canonicalize constants into the RHS. - if (C0 && !C1) { + // Canonicalize constants to the RHS. + if (isa<ConstantFP>(Arg0) && !isa<ConstantFP>(Arg1)) { II->setArgOperand(0, Arg1); II->setArgOperand(1, Arg0); return II; } - - // fmin(x, nan) -> x - if (C1 && C1->isNaN()) - return ReplaceInstUsesWith(CI, Arg0); - - // This is the value because if undef were NaN, we would return the other - // value and cannot return a NaN unless both operands are. 
- // - // fmin(undef, x) -> x - if (isa<UndefValue>(Arg0)) - return ReplaceInstUsesWith(CI, Arg1); - - // fmin(x, undef) -> x - if (isa<UndefValue>(Arg1)) - return ReplaceInstUsesWith(CI, Arg0); - - Value *X = nullptr; - Value *Y = nullptr; - if (II->getIntrinsicID() == Intrinsic::minnum) { - // fmin(x, fmin(x, y)) -> fmin(x, y) - // fmin(y, fmin(x, y)) -> fmin(x, y) - if (match(Arg1, m_FMin(m_Value(X), m_Value(Y)))) { - if (Arg0 == X || Arg0 == Y) - return ReplaceInstUsesWith(CI, Arg1); - } - - // fmin(fmin(x, y), x) -> fmin(x, y) - // fmin(fmin(x, y), y) -> fmin(x, y) - if (match(Arg0, m_FMin(m_Value(X), m_Value(Y)))) { - if (Arg1 == X || Arg1 == Y) - return ReplaceInstUsesWith(CI, Arg0); - } - - // TODO: fmin(nnan x, inf) -> x - // TODO: fmin(nnan ninf x, flt_max) -> x - if (C1 && C1->isInfinity()) { - // fmin(x, -inf) -> -inf - if (C1->isNegative()) - return ReplaceInstUsesWith(CI, Arg1); - } - } else { - assert(II->getIntrinsicID() == Intrinsic::maxnum); - // fmax(x, fmax(x, y)) -> fmax(x, y) - // fmax(y, fmax(x, y)) -> fmax(x, y) - if (match(Arg1, m_FMax(m_Value(X), m_Value(Y)))) { - if (Arg0 == X || Arg0 == Y) - return ReplaceInstUsesWith(CI, Arg1); - } - - // fmax(fmax(x, y), x) -> fmax(x, y) - // fmax(fmax(x, y), y) -> fmax(x, y) - if (match(Arg0, m_FMax(m_Value(X), m_Value(Y)))) { - if (Arg1 == X || Arg1 == Y) - return ReplaceInstUsesWith(CI, Arg0); - } - - // TODO: fmax(nnan x, -inf) -> x - // TODO: fmax(nnan ninf x, -flt_max) -> x - if (C1 && C1->isInfinity()) { - // fmax(x, inf) -> inf - if (!C1->isNegative()) - return ReplaceInstUsesWith(CI, Arg1); - } - } + if (Value *V = simplifyMinnumMaxnum(*II)) + return replaceInstUsesWith(*II, V); break; } case Intrinsic::ppc_altivec_lvx: @@ -1041,19 +1556,6 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { } break; - case Intrinsic::x86_sse_storeu_ps: - case Intrinsic::x86_sse2_storeu_pd: - case Intrinsic::x86_sse2_storeu_dq: - // Turn X86 storeu -> store if the pointer is known aligned. - if (getOrEnforceKnownAlignment(II->getArgOperand(0), 16, DL, II, AC, DT) >= - 16) { - Type *OpPtrTy = - PointerType::getUnqual(II->getArgOperand(1)->getType()); - Value *Ptr = Builder->CreateBitCast(II->getArgOperand(0), OpPtrTy); - return new StoreInst(II->getArgOperand(1), Ptr); - } - break; - case Intrinsic::x86_vcvtph2ps_128: case Intrinsic::x86_vcvtph2ps_256: { auto Arg = II->getArgOperand(0); @@ -1070,12 +1572,12 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // Constant folding: Convert to generic half to single conversion. if (isa<ConstantAggregateZero>(Arg)) - return ReplaceInstUsesWith(*II, ConstantAggregateZero::get(RetType)); + return replaceInstUsesWith(*II, ConstantAggregateZero::get(RetType)); if (isa<ConstantDataVector>(Arg)) { auto VectorHalfAsShorts = Arg; if (RetWidth < ArgWidth) { - SmallVector<int, 8> SubVecMask; + SmallVector<uint32_t, 8> SubVecMask; for (unsigned i = 0; i != RetWidth; ++i) SubVecMask.push_back((int)i); VectorHalfAsShorts = Builder->CreateShuffleVector( @@ -1087,7 +1589,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { auto VectorHalfs = Builder->CreateBitCast(VectorHalfAsShorts, VectorHalfType); auto VectorFloats = Builder->CreateFPExt(VectorHalfs, RetType); - return ReplaceInstUsesWith(*II, VectorFloats); + return replaceInstUsesWith(*II, VectorFloats); } // We only use the lowest lanes of the argument. 
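  // "Only the lowest lanes" rewrites like the one above go through the
  // SimplifyDemandedVectorEltsLow/High lambdas, which build their demanded
  // masks with APInt. A small illustration (not part of the patch), assuming a
  // 4-element vector:
  //   APInt::getLowBitsSet(4, 1)  == 0b0001   // lane 0 only, e.g. comieq.ss
  //   APInt::getHighBitsSet(4, 3) == 0b1110   // upper lanes, e.g. round_ss arg 0
  // Lanes whose bit is clear are free to be simplified (typically to undef)
  // without affecting the intrinsic's result.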
@@ -1117,6 +1619,107 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } + case Intrinsic::x86_mmx_pmovmskb: + case Intrinsic::x86_sse_movmsk_ps: + case Intrinsic::x86_sse2_movmsk_pd: + case Intrinsic::x86_sse2_pmovmskb_128: + case Intrinsic::x86_avx_movmsk_pd_256: + case Intrinsic::x86_avx_movmsk_ps_256: + case Intrinsic::x86_avx2_pmovmskb: { + if (Value *V = simplifyX86movmsk(*II, *Builder)) + return replaceInstUsesWith(*II, V); + break; + } + + case Intrinsic::x86_sse_comieq_ss: + case Intrinsic::x86_sse_comige_ss: + case Intrinsic::x86_sse_comigt_ss: + case Intrinsic::x86_sse_comile_ss: + case Intrinsic::x86_sse_comilt_ss: + case Intrinsic::x86_sse_comineq_ss: + case Intrinsic::x86_sse_ucomieq_ss: + case Intrinsic::x86_sse_ucomige_ss: + case Intrinsic::x86_sse_ucomigt_ss: + case Intrinsic::x86_sse_ucomile_ss: + case Intrinsic::x86_sse_ucomilt_ss: + case Intrinsic::x86_sse_ucomineq_ss: + case Intrinsic::x86_sse2_comieq_sd: + case Intrinsic::x86_sse2_comige_sd: + case Intrinsic::x86_sse2_comigt_sd: + case Intrinsic::x86_sse2_comile_sd: + case Intrinsic::x86_sse2_comilt_sd: + case Intrinsic::x86_sse2_comineq_sd: + case Intrinsic::x86_sse2_ucomieq_sd: + case Intrinsic::x86_sse2_ucomige_sd: + case Intrinsic::x86_sse2_ucomigt_sd: + case Intrinsic::x86_sse2_ucomile_sd: + case Intrinsic::x86_sse2_ucomilt_sd: + case Intrinsic::x86_sse2_ucomineq_sd: { + // These intrinsics only demand the 0th element of their input vectors. If + // we can simplify the input based on that, do so now. + bool MadeChange = false; + Value *Arg0 = II->getArgOperand(0); + Value *Arg1 = II->getArgOperand(1); + unsigned VWidth = Arg0->getType()->getVectorNumElements(); + if (Value *V = SimplifyDemandedVectorEltsLow(Arg0, VWidth, 1)) { + II->setArgOperand(0, V); + MadeChange = true; + } + if (Value *V = SimplifyDemandedVectorEltsLow(Arg1, VWidth, 1)) { + II->setArgOperand(1, V); + MadeChange = true; + } + if (MadeChange) + return II; + break; + } + + case Intrinsic::x86_sse_add_ss: + case Intrinsic::x86_sse_sub_ss: + case Intrinsic::x86_sse_mul_ss: + case Intrinsic::x86_sse_div_ss: + case Intrinsic::x86_sse_min_ss: + case Intrinsic::x86_sse_max_ss: + case Intrinsic::x86_sse_cmp_ss: + case Intrinsic::x86_sse2_add_sd: + case Intrinsic::x86_sse2_sub_sd: + case Intrinsic::x86_sse2_mul_sd: + case Intrinsic::x86_sse2_div_sd: + case Intrinsic::x86_sse2_min_sd: + case Intrinsic::x86_sse2_max_sd: + case Intrinsic::x86_sse2_cmp_sd: { + // These intrinsics only demand the lowest element of the second input + // vector. + Value *Arg1 = II->getArgOperand(1); + unsigned VWidth = Arg1->getType()->getVectorNumElements(); + if (Value *V = SimplifyDemandedVectorEltsLow(Arg1, VWidth, 1)) { + II->setArgOperand(1, V); + return II; + } + break; + } + + case Intrinsic::x86_sse41_round_ss: + case Intrinsic::x86_sse41_round_sd: { + // These intrinsics demand the upper elements of the first input vector and + // the lowest element of the second input vector. + bool MadeChange = false; + Value *Arg0 = II->getArgOperand(0); + Value *Arg1 = II->getArgOperand(1); + unsigned VWidth = Arg0->getType()->getVectorNumElements(); + if (Value *V = SimplifyDemandedVectorEltsHigh(Arg0, VWidth, VWidth - 1)) { + II->setArgOperand(0, V); + MadeChange = true; + } + if (Value *V = SimplifyDemandedVectorEltsLow(Arg1, VWidth, 1)) { + II->setArgOperand(1, V); + MadeChange = true; + } + if (MadeChange) + return II; + break; + } + // Constant fold ashr( <A x Bi>, Ci ). // Constant fold lshr( <A x Bi>, Ci ). // Constant fold shl( <A x Bi>, Ci ). 
@@ -1136,8 +1739,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_avx2_pslli_d: case Intrinsic::x86_avx2_pslli_q: case Intrinsic::x86_avx2_pslli_w: - if (Value *V = SimplifyX86immshift(*II, *Builder)) - return ReplaceInstUsesWith(*II, V); + if (Value *V = simplifyX86immShift(*II, *Builder)) + return replaceInstUsesWith(*II, V); break; case Intrinsic::x86_sse2_psra_d: @@ -1156,8 +1759,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_avx2_psll_d: case Intrinsic::x86_avx2_psll_q: case Intrinsic::x86_avx2_psll_w: { - if (Value *V = SimplifyX86immshift(*II, *Builder)) - return ReplaceInstUsesWith(*II, V); + if (Value *V = simplifyX86immShift(*II, *Builder)) + return replaceInstUsesWith(*II, V); // SSE2/AVX2 uses only the first 64-bits of the 128-bit vector // operand to compute the shift amount. @@ -1173,35 +1776,23 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } - case Intrinsic::x86_avx2_pmovsxbd: - case Intrinsic::x86_avx2_pmovsxbq: - case Intrinsic::x86_avx2_pmovsxbw: - case Intrinsic::x86_avx2_pmovsxdq: - case Intrinsic::x86_avx2_pmovsxwd: - case Intrinsic::x86_avx2_pmovsxwq: - if (Value *V = SimplifyX86extend(*II, *Builder, true)) - return ReplaceInstUsesWith(*II, V); - break; - - case Intrinsic::x86_sse41_pmovzxbd: - case Intrinsic::x86_sse41_pmovzxbq: - case Intrinsic::x86_sse41_pmovzxbw: - case Intrinsic::x86_sse41_pmovzxdq: - case Intrinsic::x86_sse41_pmovzxwd: - case Intrinsic::x86_sse41_pmovzxwq: - case Intrinsic::x86_avx2_pmovzxbd: - case Intrinsic::x86_avx2_pmovzxbq: - case Intrinsic::x86_avx2_pmovzxbw: - case Intrinsic::x86_avx2_pmovzxdq: - case Intrinsic::x86_avx2_pmovzxwd: - case Intrinsic::x86_avx2_pmovzxwq: - if (Value *V = SimplifyX86extend(*II, *Builder, false)) - return ReplaceInstUsesWith(*II, V); + case Intrinsic::x86_avx2_psllv_d: + case Intrinsic::x86_avx2_psllv_d_256: + case Intrinsic::x86_avx2_psllv_q: + case Intrinsic::x86_avx2_psllv_q_256: + case Intrinsic::x86_avx2_psrav_d: + case Intrinsic::x86_avx2_psrav_d_256: + case Intrinsic::x86_avx2_psrlv_d: + case Intrinsic::x86_avx2_psrlv_d_256: + case Intrinsic::x86_avx2_psrlv_q: + case Intrinsic::x86_avx2_psrlv_q_256: + if (Value *V = simplifyX86varShift(*II, *Builder)) + return replaceInstUsesWith(*II, V); break; case Intrinsic::x86_sse41_insertps: - if (Value *V = SimplifyX86insertps(*II, *Builder)) - return ReplaceInstUsesWith(*II, V); + if (Value *V = simplifyX86insertps(*II, *Builder)) + return replaceInstUsesWith(*II, V); break; case Intrinsic::x86_sse4a_extrq: { @@ -1223,19 +1814,22 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { : nullptr; // Attempt to simplify to a constant, shuffle vector or EXTRQI call. - if (Value *V = SimplifyX86extrq(*II, Op0, CILength, CIIndex, *Builder)) - return ReplaceInstUsesWith(*II, V); + if (Value *V = simplifyX86extrq(*II, Op0, CILength, CIIndex, *Builder)) + return replaceInstUsesWith(*II, V); // EXTRQ only uses the lowest 64-bits of the first 128-bit vector // operands and the lowest 16-bits of the second. 
+ bool MadeChange = false; if (Value *V = SimplifyDemandedVectorEltsLow(Op0, VWidth0, 1)) { II->setArgOperand(0, V); - return II; + MadeChange = true; } if (Value *V = SimplifyDemandedVectorEltsLow(Op1, VWidth1, 2)) { II->setArgOperand(1, V); - return II; + MadeChange = true; } + if (MadeChange) + return II; break; } @@ -1252,8 +1846,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { ConstantInt *CIIndex = dyn_cast<ConstantInt>(II->getArgOperand(2)); // Attempt to simplify to a constant or shuffle vector. - if (Value *V = SimplifyX86extrq(*II, Op0, CILength, CIIndex, *Builder)) - return ReplaceInstUsesWith(*II, V); + if (Value *V = simplifyX86extrq(*II, Op0, CILength, CIIndex, *Builder)) + return replaceInstUsesWith(*II, V); // EXTRQI only uses the lowest 64-bits of the first 128-bit vector // operand. @@ -1281,11 +1875,11 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // Attempt to simplify to a constant, shuffle vector or INSERTQI call. if (CI11) { - APInt V11 = CI11->getValue(); + const APInt &V11 = CI11->getValue(); APInt Len = V11.zextOrTrunc(6); APInt Idx = V11.lshr(8).zextOrTrunc(6); - if (Value *V = SimplifyX86insertq(*II, Op0, Op1, Len, Idx, *Builder)) - return ReplaceInstUsesWith(*II, V); + if (Value *V = simplifyX86insertq(*II, Op0, Op1, Len, Idx, *Builder)) + return replaceInstUsesWith(*II, V); } // INSERTQ only uses the lowest 64-bits of the first 128-bit vector @@ -1317,21 +1911,23 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { if (CILength && CIIndex) { APInt Len = CILength->getValue().zextOrTrunc(6); APInt Idx = CIIndex->getValue().zextOrTrunc(6); - if (Value *V = SimplifyX86insertq(*II, Op0, Op1, Len, Idx, *Builder)) - return ReplaceInstUsesWith(*II, V); + if (Value *V = simplifyX86insertq(*II, Op0, Op1, Len, Idx, *Builder)) + return replaceInstUsesWith(*II, V); } // INSERTQI only uses the lowest 64-bits of the first two 128-bit vector // operands. + bool MadeChange = false; if (Value *V = SimplifyDemandedVectorEltsLow(Op0, VWidth0, 1)) { II->setArgOperand(0, V); - return II; + MadeChange = true; } - if (Value *V = SimplifyDemandedVectorEltsLow(Op1, VWidth1, 1)) { II->setArgOperand(1, V); - return II; + MadeChange = true; } + if (MadeChange) + return II; break; } @@ -1352,143 +1948,87 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // fold (blend A, A, Mask) -> A if (Op0 == Op1) - return ReplaceInstUsesWith(CI, Op0); + return replaceInstUsesWith(CI, Op0); // Zero Mask - select 1st argument. if (isa<ConstantAggregateZero>(Mask)) - return ReplaceInstUsesWith(CI, Op0); + return replaceInstUsesWith(CI, Op0); // Constant Mask - select 1st/2nd argument lane based on top bit of mask. - if (auto C = dyn_cast<ConstantDataVector>(Mask)) { - auto Tyi1 = Builder->getInt1Ty(); - auto SelectorType = cast<VectorType>(Mask->getType()); - auto EltTy = SelectorType->getElementType(); - unsigned Size = SelectorType->getNumElements(); - unsigned BitWidth = - EltTy->isFloatTy() - ? 32 - : (EltTy->isDoubleTy() ? 
64 : EltTy->getIntegerBitWidth()); - assert((BitWidth == 64 || BitWidth == 32 || BitWidth == 8) && - "Wrong arguments for variable blend intrinsic"); - SmallVector<Constant *, 32> Selectors; - for (unsigned I = 0; I < Size; ++I) { - // The intrinsics only read the top bit - uint64_t Selector; - if (BitWidth == 8) - Selector = C->getElementAsInteger(I); - else - Selector = C->getElementAsAPFloat(I).bitcastToAPInt().getZExtValue(); - Selectors.push_back(ConstantInt::get(Tyi1, Selector >> (BitWidth - 1))); - } - auto NewSelector = ConstantVector::get(Selectors); + if (auto *ConstantMask = dyn_cast<ConstantDataVector>(Mask)) { + Constant *NewSelector = getNegativeIsTrueBoolVec(ConstantMask); return SelectInst::Create(NewSelector, Op1, Op0, "blendv"); } break; } case Intrinsic::x86_ssse3_pshuf_b_128: - case Intrinsic::x86_avx2_pshuf_b: { - // Turn pshufb(V1,mask) -> shuffle(V1,Zero,mask) if mask is a constant. - auto *V = II->getArgOperand(1); - auto *VTy = cast<VectorType>(V->getType()); - unsigned NumElts = VTy->getNumElements(); - assert((NumElts == 16 || NumElts == 32) && - "Unexpected number of elements in shuffle mask!"); - // Initialize the resulting shuffle mask to all zeroes. - uint32_t Indexes[32] = {0}; - - if (auto *Mask = dyn_cast<ConstantDataVector>(V)) { - // Each byte in the shuffle control mask forms an index to permute the - // corresponding byte in the destination operand. - for (unsigned I = 0; I < NumElts; ++I) { - int8_t Index = Mask->getElementAsInteger(I); - // If the most significant bit (bit[7]) of each byte of the shuffle - // control mask is set, then zero is written in the result byte. - // The zero vector is in the right-hand side of the resulting - // shufflevector. - - // The value of each index is the least significant 4 bits of the - // shuffle control byte. - Indexes[I] = (Index < 0) ? NumElts : Index & 0xF; - } - } else if (!isa<ConstantAggregateZero>(V)) - break; - - // The value of each index for the high 128-bit lane is the least - // significant 4 bits of the respective shuffle control byte. - for (unsigned I = 16; I < NumElts; ++I) - Indexes[I] += I & 0xF0; - - auto NewC = ConstantDataVector::get(V->getContext(), - makeArrayRef(Indexes, NumElts)); - auto V1 = II->getArgOperand(0); - auto V2 = Constant::getNullValue(II->getType()); - auto Shuffle = Builder->CreateShuffleVector(V1, V2, NewC); - return ReplaceInstUsesWith(CI, Shuffle); - } + case Intrinsic::x86_avx2_pshuf_b: + if (Value *V = simplifyX86pshufb(*II, *Builder)) + return replaceInstUsesWith(*II, V); + break; case Intrinsic::x86_avx_vpermilvar_ps: case Intrinsic::x86_avx_vpermilvar_ps_256: case Intrinsic::x86_avx_vpermilvar_pd: - case Intrinsic::x86_avx_vpermilvar_pd_256: { - // Convert vpermil* to shufflevector if the mask is constant. - Value *V = II->getArgOperand(1); - unsigned Size = cast<VectorType>(V->getType())->getNumElements(); - assert(Size == 8 || Size == 4 || Size == 2); - uint32_t Indexes[8]; - if (auto C = dyn_cast<ConstantDataVector>(V)) { - // The intrinsics only read one or two bits, clear the rest. 
- for (unsigned I = 0; I < Size; ++I) { - uint32_t Index = C->getElementAsInteger(I) & 0x3; - if (II->getIntrinsicID() == Intrinsic::x86_avx_vpermilvar_pd || - II->getIntrinsicID() == Intrinsic::x86_avx_vpermilvar_pd_256) - Index >>= 1; - Indexes[I] = Index; - } - } else if (isa<ConstantAggregateZero>(V)) { - for (unsigned I = 0; I < Size; ++I) - Indexes[I] = 0; - } else { - break; - } - // The _256 variants are a bit trickier since the mask bits always index - // into the corresponding 128 half. In order to convert to a generic - // shuffle, we have to make that explicit. - if (II->getIntrinsicID() == Intrinsic::x86_avx_vpermilvar_ps_256 || - II->getIntrinsicID() == Intrinsic::x86_avx_vpermilvar_pd_256) { - for (unsigned I = Size / 2; I < Size; ++I) - Indexes[I] += Size / 2; - } - auto NewC = - ConstantDataVector::get(V->getContext(), makeArrayRef(Indexes, Size)); - auto V1 = II->getArgOperand(0); - auto V2 = UndefValue::get(V1->getType()); - auto Shuffle = Builder->CreateShuffleVector(V1, V2, NewC); - return ReplaceInstUsesWith(CI, Shuffle); - } + case Intrinsic::x86_avx_vpermilvar_pd_256: + if (Value *V = simplifyX86vpermilvar(*II, *Builder)) + return replaceInstUsesWith(*II, V); + break; + + case Intrinsic::x86_avx2_permd: + case Intrinsic::x86_avx2_permps: + if (Value *V = simplifyX86vpermv(*II, *Builder)) + return replaceInstUsesWith(*II, V); + break; case Intrinsic::x86_avx_vperm2f128_pd_256: case Intrinsic::x86_avx_vperm2f128_ps_256: case Intrinsic::x86_avx_vperm2f128_si_256: case Intrinsic::x86_avx2_vperm2i128: - if (Value *V = SimplifyX86vperm2(*II, *Builder)) - return ReplaceInstUsesWith(*II, V); + if (Value *V = simplifyX86vperm2(*II, *Builder)) + return replaceInstUsesWith(*II, V); + break; + + case Intrinsic::x86_avx_maskload_ps: + case Intrinsic::x86_avx_maskload_pd: + case Intrinsic::x86_avx_maskload_ps_256: + case Intrinsic::x86_avx_maskload_pd_256: + case Intrinsic::x86_avx2_maskload_d: + case Intrinsic::x86_avx2_maskload_q: + case Intrinsic::x86_avx2_maskload_d_256: + case Intrinsic::x86_avx2_maskload_q_256: + if (Instruction *I = simplifyX86MaskedLoad(*II, *this)) + return I; + break; + + case Intrinsic::x86_sse2_maskmov_dqu: + case Intrinsic::x86_avx_maskstore_ps: + case Intrinsic::x86_avx_maskstore_pd: + case Intrinsic::x86_avx_maskstore_ps_256: + case Intrinsic::x86_avx_maskstore_pd_256: + case Intrinsic::x86_avx2_maskstore_d: + case Intrinsic::x86_avx2_maskstore_q: + case Intrinsic::x86_avx2_maskstore_d_256: + case Intrinsic::x86_avx2_maskstore_q_256: + if (simplifyX86MaskedStore(*II, *this)) + return nullptr; break; case Intrinsic::x86_xop_vpcomb: case Intrinsic::x86_xop_vpcomd: case Intrinsic::x86_xop_vpcomq: case Intrinsic::x86_xop_vpcomw: - if (Value *V = SimplifyX86vpcom(*II, *Builder, true)) - return ReplaceInstUsesWith(*II, V); + if (Value *V = simplifyX86vpcom(*II, *Builder, true)) + return replaceInstUsesWith(*II, V); break; case Intrinsic::x86_xop_vpcomub: case Intrinsic::x86_xop_vpcomud: case Intrinsic::x86_xop_vpcomuq: case Intrinsic::x86_xop_vpcomuw: - if (Value *V = SimplifyX86vpcom(*II, *Builder, false)) - return ReplaceInstUsesWith(*II, V); + if (Value *V = simplifyX86vpcom(*II, *Builder, false)) + return replaceInstUsesWith(*II, V); break; case Intrinsic::ppc_altivec_vperm: @@ -1585,7 +2125,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // Handle mul by zero first: if (isa<ConstantAggregateZero>(Arg0) || isa<ConstantAggregateZero>(Arg1)) { - return ReplaceInstUsesWith(CI, ConstantAggregateZero::get(II->getType())); + return 
replaceInstUsesWith(CI, ConstantAggregateZero::get(II->getType())); } // Check for constant LHS & RHS - in this case we just simplify. @@ -1597,7 +2137,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { CV0 = ConstantExpr::getIntegerCast(CV0, NewVT, /*isSigned=*/!Zext); CV1 = ConstantExpr::getIntegerCast(CV1, NewVT, /*isSigned=*/!Zext); - return ReplaceInstUsesWith(CI, ConstantExpr::getMul(CV0, CV1)); + return replaceInstUsesWith(CI, ConstantExpr::getMul(CV0, CV1)); } // Couldn't simplify - canonicalize constant to the RHS. @@ -1615,7 +2155,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } - case Intrinsic::AMDGPU_rcp: { + case Intrinsic::amdgcn_rcp: { if (const ConstantFP *C = dyn_cast<ConstantFP>(II->getArgOperand(0))) { const APFloat &ArgVal = C->getValueAPF(); APFloat Val(ArgVal.getSemantics(), 1.0); @@ -1624,18 +2164,43 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // Only do this if it was exact and therefore not dependent on the // rounding mode. if (Status == APFloat::opOK) - return ReplaceInstUsesWith(CI, ConstantFP::get(II->getContext(), Val)); + return replaceInstUsesWith(CI, ConstantFP::get(II->getContext(), Val)); } break; } + case Intrinsic::amdgcn_frexp_mant: + case Intrinsic::amdgcn_frexp_exp: { + Value *Src = II->getArgOperand(0); + if (const ConstantFP *C = dyn_cast<ConstantFP>(Src)) { + int Exp; + APFloat Significand = frexp(C->getValueAPF(), Exp, + APFloat::rmNearestTiesToEven); + + if (II->getIntrinsicID() == Intrinsic::amdgcn_frexp_mant) { + return replaceInstUsesWith(CI, ConstantFP::get(II->getContext(), + Significand)); + } + + // Match instruction special case behavior. + if (Exp == APFloat::IEK_NaN || Exp == APFloat::IEK_Inf) + Exp = 0; + + return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), Exp)); + } + + if (isa<UndefValue>(Src)) + return replaceInstUsesWith(CI, UndefValue::get(II->getType())); + + break; + } case Intrinsic::stackrestore: { // If the save is right next to the restore, remove the restore. This can // happen when variable allocas are DCE'd. if (IntrinsicInst *SS = dyn_cast<IntrinsicInst>(II->getArgOperand(0))) { if (SS->getIntrinsicID() == Intrinsic::stacksave) { if (&*++SS->getIterator() == II) - return EraseInstFromFunction(CI); + return eraseInstFromFunction(CI); } } @@ -1653,8 +2218,14 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(BCI)) { // If there is a stackrestore below this one, remove this one. if (II->getIntrinsicID() == Intrinsic::stackrestore) - return EraseInstFromFunction(CI); - // Otherwise, ignore the intrinsic. + return eraseInstFromFunction(CI); + + // Bail if we cross over an intrinsic with side effects, such as + // llvm.stacksave, llvm.read_register, or llvm.setjmp. + if (II->mayHaveSideEffects()) { + CannotRemove = true; + break; + } } else { // If we found a non-intrinsic call, we can't remove the stack // restore. @@ -1668,42 +2239,29 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // are no allocas or calls between the restore and the return, nuke the // restore. if (!CannotRemove && (isa<ReturnInst>(TI) || isa<ResumeInst>(TI))) - return EraseInstFromFunction(CI); + return eraseInstFromFunction(CI); break; } - case Intrinsic::lifetime_start: { - // Remove trivially empty lifetime_start/end ranges, i.e. a start - // immediately followed by an end (ignoring debuginfo or other - // lifetime markers in between). 
- BasicBlock::iterator BI = II->getIterator(), BE = II->getParent()->end(); - for (++BI; BI != BE; ++BI) { - if (IntrinsicInst *LTE = dyn_cast<IntrinsicInst>(BI)) { - if (isa<DbgInfoIntrinsic>(LTE) || - LTE->getIntrinsicID() == Intrinsic::lifetime_start) - continue; - if (LTE->getIntrinsicID() == Intrinsic::lifetime_end) { - if (II->getOperand(0) == LTE->getOperand(0) && - II->getOperand(1) == LTE->getOperand(1)) { - EraseInstFromFunction(*LTE); - return EraseInstFromFunction(*II); - } - continue; - } - } - break; - } + case Intrinsic::lifetime_start: + if (removeTriviallyEmptyRange(*II, Intrinsic::lifetime_start, + Intrinsic::lifetime_end, *this)) + return nullptr; break; - } case Intrinsic::assume: { + Value *IIOperand = II->getArgOperand(0); + // Remove an assume if it is immediately followed by an identical assume. + if (match(II->getNextNode(), + m_Intrinsic<Intrinsic::assume>(m_Specific(IIOperand)))) + return eraseInstFromFunction(CI); + // Canonicalize assume(a && b) -> assume(a); assume(b); // Note: New assumption intrinsics created here are registered by // the InstCombineIRInserter object. - Value *IIOperand = II->getArgOperand(0), *A, *B, - *AssumeIntrinsic = II->getCalledValue(); + Value *AssumeIntrinsic = II->getCalledValue(), *A, *B; if (match(IIOperand, m_And(m_Value(A), m_Value(B)))) { Builder->CreateCall(AssumeIntrinsic, A, II->getName()); Builder->CreateCall(AssumeIntrinsic, B, II->getName()); - return EraseInstFromFunction(*II); + return eraseInstFromFunction(*II); } // assume(!(a || b)) -> assume(!a); assume(!b); if (match(IIOperand, m_Not(m_Or(m_Value(A), m_Value(B))))) { @@ -1711,7 +2269,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { II->getName()); Builder->CreateCall(AssumeIntrinsic, Builder->CreateNot(B), II->getName()); - return EraseInstFromFunction(*II); + return eraseInstFromFunction(*II); } // assume( (load addr) != null ) -> add 'nonnull' metadata to load @@ -1728,7 +2286,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { if (isValidAssumeForContext(II, LI, DT)) { MDNode *MD = MDNode::get(II->getContext(), None); LI->setMetadata(LLVMContext::MD_nonnull, MD); - return EraseInstFromFunction(*II); + return eraseInstFromFunction(*II); } } // TODO: apply nonnull return attributes to calls and invokes @@ -1739,7 +2297,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { APInt KnownZero(1, 0), KnownOne(1, 0); computeKnownBits(IIOperand, KnownZero, KnownOne, 0, II); if (KnownOne.isAllOnesValue()) - return EraseInstFromFunction(*II); + return eraseInstFromFunction(*II); break; } @@ -1748,46 +2306,38 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // facts about the relocate value, while being careful to // preserve relocation semantics. Value *DerivedPtr = cast<GCRelocateInst>(II)->getDerivedPtr(); - auto *GCRelocateType = cast<PointerType>(II->getType()); // Remove the relocation if unused, note that this check is required // to prevent the cases below from looping forever. if (II->use_empty()) - return EraseInstFromFunction(*II); + return eraseInstFromFunction(*II); // Undef is undef, even after relocation. // TODO: provide a hook for this in GCStrategy. This is clearly legal for // most practical collectors, but there was discussion in the review thread // about whether it was legal for all possible collectors. - if (isa<UndefValue>(DerivedPtr)) { - // gc_relocate is uncasted. Use undef of gc_relocate's type to replace it. 
- return ReplaceInstUsesWith(*II, UndefValue::get(GCRelocateType)); - } + if (isa<UndefValue>(DerivedPtr)) + // Use undef of gc_relocate's type to replace it. + return replaceInstUsesWith(*II, UndefValue::get(II->getType())); - // The relocation of null will be null for most any collector. - // TODO: provide a hook for this in GCStrategy. There might be some weird - // collector this property does not hold for. - if (isa<ConstantPointerNull>(DerivedPtr)) { - // gc_relocate is uncasted. Use null-pointer of gc_relocate's type to replace it. - return ReplaceInstUsesWith(*II, ConstantPointerNull::get(GCRelocateType)); - } + if (auto *PT = dyn_cast<PointerType>(II->getType())) { + // The relocation of null will be null for most any collector. + // TODO: provide a hook for this in GCStrategy. There might be some + // weird collector this property does not hold for. + if (isa<ConstantPointerNull>(DerivedPtr)) + // Use null-pointer of gc_relocate's type to replace it. + return replaceInstUsesWith(*II, ConstantPointerNull::get(PT)); - // isKnownNonNull -> nonnull attribute - if (isKnownNonNullAt(DerivedPtr, II, DT, TLI)) - II->addAttribute(AttributeSet::ReturnIndex, Attribute::NonNull); - - // isDereferenceablePointer -> deref attribute - if (isDereferenceablePointer(DerivedPtr, DL)) { - if (Argument *A = dyn_cast<Argument>(DerivedPtr)) { - uint64_t Bytes = A->getDereferenceableBytes(); - II->addDereferenceableAttr(AttributeSet::ReturnIndex, Bytes); - } + // isKnownNonNull -> nonnull attribute + if (isKnownNonNullAt(DerivedPtr, II, DT)) + II->addAttribute(AttributeSet::ReturnIndex, Attribute::NonNull); } // TODO: bitcast(relocate(p)) -> relocate(bitcast(p)) // Canonicalize on the type from the uses to the defs // TODO: relocate((gep p, C, C2, ...)) -> gep(relocate(p), C, C2, ...) + break; } } @@ -1800,8 +2350,8 @@ Instruction *InstCombiner::visitInvokeInst(InvokeInst &II) { return visitCallSite(&II); } -/// isSafeToEliminateVarargsCast - If this cast does not affect the value -/// passed through the varargs area, we can eliminate the use of the cast. +/// If this cast does not affect the value passed through the varargs area, we +/// can eliminate the use of the cast. static bool isSafeToEliminateVarargsCast(const CallSite CS, const DataLayout &DL, const CastInst *const CI, @@ -1833,26 +2383,22 @@ static bool isSafeToEliminateVarargsCast(const CallSite CS, return true; } -// Try to fold some different type of calls here. -// Currently we're only working with the checking functions, memcpy_chk, -// mempcpy_chk, memmove_chk, memset_chk, strcpy_chk, stpcpy_chk, strncpy_chk, -// strcat_chk and strncat_chk. Instruction *InstCombiner::tryOptimizeCall(CallInst *CI) { if (!CI->getCalledFunction()) return nullptr; auto InstCombineRAUW = [this](Instruction *From, Value *With) { - ReplaceInstUsesWith(*From, With); + replaceInstUsesWith(*From, With); }; LibCallSimplifier Simplifier(DL, TLI, InstCombineRAUW); if (Value *With = Simplifier.optimizeCall(CI)) { ++NumSimplified; - return CI->use_empty() ? CI : ReplaceInstUsesWith(*CI, With); + return CI->use_empty() ? CI : replaceInstUsesWith(*CI, With); } return nullptr; } -static IntrinsicInst *FindInitTrampolineFromAlloca(Value *TrampMem) { +static IntrinsicInst *findInitTrampolineFromAlloca(Value *TrampMem) { // Strip off at most one level of pointer casts, looking for an alloca. This // is good enough in practice and simpler than handling any number of casts. 
Value *Underlying = TrampMem->stripPointerCasts(); @@ -1891,7 +2437,7 @@ static IntrinsicInst *FindInitTrampolineFromAlloca(Value *TrampMem) { return InitTrampoline; } -static IntrinsicInst *FindInitTrampolineFromBB(IntrinsicInst *AdjustTramp, +static IntrinsicInst *findInitTrampolineFromBB(IntrinsicInst *AdjustTramp, Value *TrampMem) { // Visit all the previous instructions in the basic block, and try to find a // init.trampoline which has a direct path to the adjust.trampoline. @@ -1913,7 +2459,7 @@ static IntrinsicInst *FindInitTrampolineFromBB(IntrinsicInst *AdjustTramp, // call to llvm.init.trampoline if the call to the trampoline can be optimized // to a direct call to a function. Otherwise return NULL. // -static IntrinsicInst *FindInitTrampoline(Value *Callee) { +static IntrinsicInst *findInitTrampoline(Value *Callee) { Callee = Callee->stripPointerCasts(); IntrinsicInst *AdjustTramp = dyn_cast<IntrinsicInst>(Callee); if (!AdjustTramp || @@ -1922,15 +2468,14 @@ static IntrinsicInst *FindInitTrampoline(Value *Callee) { Value *TrampMem = AdjustTramp->getOperand(0); - if (IntrinsicInst *IT = FindInitTrampolineFromAlloca(TrampMem)) + if (IntrinsicInst *IT = findInitTrampolineFromAlloca(TrampMem)) return IT; - if (IntrinsicInst *IT = FindInitTrampolineFromBB(AdjustTramp, TrampMem)) + if (IntrinsicInst *IT = findInitTrampolineFromBB(AdjustTramp, TrampMem)) return IT; return nullptr; } -// visitCallSite - Improvements for call and invoke instructions. -// +/// Improvements for call and invoke instructions. Instruction *InstCombiner::visitCallSite(CallSite CS) { if (isAllocLikeFn(CS.getInstruction(), TLI)) @@ -1945,8 +2490,9 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) { unsigned ArgNo = 0; for (Value *V : CS.args()) { - if (V->getType()->isPointerTy() && !CS.paramHasAttr(ArgNo+1, Attribute::NonNull) && - isKnownNonNullAt(V, CS.getInstruction(), DT, TLI)) + if (V->getType()->isPointerTy() && + !CS.paramHasAttr(ArgNo + 1, Attribute::NonNull) && + isKnownNonNullAt(V, CS.getInstruction(), DT)) Indices.push_back(ArgNo + 1); ArgNo++; } @@ -1968,7 +2514,16 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) { if (!isa<Function>(Callee) && transformConstExprCastCall(CS)) return nullptr; - if (Function *CalleeF = dyn_cast<Function>(Callee)) + if (Function *CalleeF = dyn_cast<Function>(Callee)) { + // Remove the convergent attr on calls when the callee is not convergent. + if (CS.isConvergent() && !CalleeF->isConvergent() && + !CalleeF->isIntrinsic()) { + DEBUG(dbgs() << "Removing convergent attr from instr " + << CS.getInstruction() << "\n"); + CS.setNotConvergent(); + return CS.getInstruction(); + } + // If the call and callee calling conventions don't match, this call must // be unreachable, as the call is undefined. if (CalleeF->getCallingConv() != CS.getCallingConv() && @@ -1983,9 +2538,9 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) { // If OldCall does not return void then replaceAllUsesWith undef. // This allows ValueHandlers and custom metadata to adjust itself. if (!OldCall->getType()->isVoidTy()) - ReplaceInstUsesWith(*OldCall, UndefValue::get(OldCall->getType())); + replaceInstUsesWith(*OldCall, UndefValue::get(OldCall->getType())); if (isa<CallInst>(OldCall)) - return EraseInstFromFunction(*OldCall); + return eraseInstFromFunction(*OldCall); // We cannot remove an invoke, because it would change the CFG, just // change the callee to a null pointer. 
@@ -1993,12 +2548,13 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) { Constant::getNullValue(CalleeF->getType())); return nullptr; } + } if (isa<ConstantPointerNull>(Callee) || isa<UndefValue>(Callee)) { // If CS does not return void then replaceAllUsesWith undef. // This allows ValueHandlers and custom metadata to adjust itself. if (!CS.getInstruction()->getType()->isVoidTy()) - ReplaceInstUsesWith(*CS.getInstruction(), + replaceInstUsesWith(*CS.getInstruction(), UndefValue::get(CS.getInstruction()->getType())); if (isa<InvokeInst>(CS.getInstruction())) { @@ -2013,10 +2569,10 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) { UndefValue::get(Type::getInt1PtrTy(Callee->getContext())), CS.getInstruction()); - return EraseInstFromFunction(*CS.getInstruction()); + return eraseInstFromFunction(*CS.getInstruction()); } - if (IntrinsicInst *II = FindInitTrampoline(Callee)) + if (IntrinsicInst *II = findInitTrampoline(Callee)) return transformCallThroughTrampoline(CS, II); PointerType *PTy = cast<PointerType>(Callee->getType()); @@ -2048,15 +2604,14 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) { Instruction *I = tryOptimizeCall(CI); // If we changed something return the result, etc. Otherwise let // the fallthrough check. - if (I) return EraseInstFromFunction(*I); + if (I) return eraseInstFromFunction(*I); } return Changed ? CS.getInstruction() : nullptr; } -// transformConstExprCastCall - If the callee is a constexpr cast of a function, -// attempt to move the cast to the arguments of the call/invoke. -// +/// If the callee is a constexpr cast of a function, attempt to move the cast to +/// the arguments of the call/invoke. bool InstCombiner::transformConstExprCastCall(CallSite CS) { Function *Callee = dyn_cast<Function>(CS.getCalledValue()->stripPointerCasts()); @@ -2316,7 +2871,7 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { } if (!Caller->use_empty()) - ReplaceInstUsesWith(*Caller, NV); + replaceInstUsesWith(*Caller, NV); else if (Caller->hasValueHandle()) { if (OldRetTy == NV->getType()) ValueHandleBase::ValueIsRAUWd(Caller, NV); @@ -2326,14 +2881,12 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { ValueHandleBase::ValueIsDeleted(Caller); } - EraseInstFromFunction(*Caller); + eraseInstFromFunction(*Caller); return true; } -// transformCallThroughTrampoline - Turn a call to a function created by -// init_trampoline / adjust_trampoline intrinsic pair into a direct call to the -// underlying function. -// +/// Turn a call to a function created by init_trampoline / adjust_trampoline +/// intrinsic pair into a direct call to the underlying function. Instruction * InstCombiner::transformCallThroughTrampoline(CallSite CS, IntrinsicInst *Tramp) { @@ -2351,8 +2904,7 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS, "transformCallThroughTrampoline called with incorrect CallSite."); Function *NestF =cast<Function>(Tramp->getArgOperand(1)->stripPointerCasts()); - PointerType *NestFPTy = cast<PointerType>(NestF->getType()); - FunctionType *NestFTy = cast<FunctionType>(NestFPTy->getElementType()); + FunctionType *NestFTy = cast<FunctionType>(NestF->getValueType()); const AttributeSet &NestAttrs = NestF->getAttributes(); if (!NestAttrs.isEmpty()) { @@ -2412,7 +2964,8 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS, Idx + (Idx >= NestIdx), B)); } - ++Idx, ++I; + ++Idx; + ++I; } while (1); } @@ -2446,7 +2999,8 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS, // Add the original type. 
NewTypes.push_back(*I); - ++Idx, ++I; + ++Idx; + ++I; } while (1); } @@ -2461,15 +3015,18 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS, const AttributeSet &NewPAL = AttributeSet::get(FTy->getContext(), NewAttrs); + SmallVector<OperandBundleDef, 1> OpBundles; + CS.getOperandBundlesAsDefs(OpBundles); + Instruction *NewCaller; if (InvokeInst *II = dyn_cast<InvokeInst>(Caller)) { NewCaller = InvokeInst::Create(NewCallee, II->getNormalDest(), II->getUnwindDest(), - NewArgs); + NewArgs, OpBundles); cast<InvokeInst>(NewCaller)->setCallingConv(II->getCallingConv()); cast<InvokeInst>(NewCaller)->setAttributes(NewPAL); } else { - NewCaller = CallInst::Create(NewCallee, NewArgs); + NewCaller = CallInst::Create(NewCallee, NewArgs, OpBundles); if (cast<CallInst>(Caller)->isTailCall()) cast<CallInst>(NewCaller)->setTailCall(); cast<CallInst>(NewCaller)-> diff --git a/lib/Transforms/InstCombine/InstCombineCasts.cpp b/lib/Transforms/InstCombine/InstCombineCasts.cpp index 0f01d183b1adc..20556157188f4 100644 --- a/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -149,9 +149,9 @@ Instruction *InstCombiner::PromoteCastOfAllocation(BitCastInst &CI, // New is the allocation instruction, pointer typed. AI is the original // allocation instruction, also pointer typed. Thus, cast to use is BitCast. Value *NewCast = AllocaBuilder.CreateBitCast(New, AI.getType(), "tmpcast"); - ReplaceInstUsesWith(AI, NewCast); + replaceInstUsesWith(AI, NewCast); } - return ReplaceInstUsesWith(CI, New); + return replaceInstUsesWith(CI, New); } /// Given an expression that CanEvaluateTruncated or CanEvaluateSExtd returns @@ -508,7 +508,7 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { " to avoid cast: " << CI << '\n'); Value *Res = EvaluateInDifferentType(Src, DestTy, false); assert(Res->getType() == DestTy); - return ReplaceInstUsesWith(CI, Res); + return replaceInstUsesWith(CI, Res); } // Canonicalize trunc x to i1 -> (icmp ne (and x, 1), 0), likewise for vector. @@ -532,7 +532,7 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { // If the shift amount is larger than the size of A, then the result is // known to be zero because all the input bits got shifted out. if (Cst->getZExtValue() >= ASize) - return ReplaceInstUsesWith(CI, Constant::getNullValue(DestTy)); + return replaceInstUsesWith(CI, Constant::getNullValue(DestTy)); // Since we're doing an lshr and a zero extend, and know that the shift // amount is smaller than ASize, it is always safe to do the shift in A's @@ -606,7 +606,7 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, Instruction &CI, In = Builder->CreateXor(In, One, In->getName() + ".not"); } - return ReplaceInstUsesWith(CI, In); + return replaceInstUsesWith(CI, In); } // zext (X == 0) to i32 --> X^1 iff X has only the low bit set. 
@@ -636,7 +636,7 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, Instruction &CI, Constant *Res = ConstantInt::get(Type::getInt1Ty(CI.getContext()), isNE); Res = ConstantExpr::getZExt(Res, CI.getType()); - return ReplaceInstUsesWith(CI, Res); + return replaceInstUsesWith(CI, Res); } uint32_t ShAmt = KnownZeroMask.logBase2(); @@ -654,7 +654,7 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, Instruction &CI, } if (CI.getType() == In->getType()) - return ReplaceInstUsesWith(CI, In); + return replaceInstUsesWith(CI, In); return CastInst::CreateIntegerCast(In, CI.getType(), false/*ZExt*/); } } @@ -694,7 +694,7 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, Instruction &CI, if (ICI->getPredicate() == ICmpInst::ICMP_EQ) Result = Builder->CreateXor(Result, ConstantInt::get(ITy, 1)); Result->takeName(ICI); - return ReplaceInstUsesWith(CI, Result); + return replaceInstUsesWith(CI, Result); } } } @@ -872,7 +872,7 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) { APInt::getHighBitsSet(DestBitSize, DestBitSize-SrcBitsKept), 0, &CI)) - return ReplaceInstUsesWith(CI, Res); + return replaceInstUsesWith(CI, Res); // We need to emit an AND to clear the high bits. Constant *C = ConstantInt::get(Res->getType(), @@ -986,7 +986,7 @@ Instruction *InstCombiner::transformSExtICmp(ICmpInst *ICI, Instruction &CI) { if (Pred == ICmpInst::ICMP_SGT) In = Builder->CreateNot(In, In->getName()+".not"); - return ReplaceInstUsesWith(CI, In); + return replaceInstUsesWith(CI, In); } } @@ -1009,7 +1009,7 @@ Instruction *InstCombiner::transformSExtICmp(ICmpInst *ICI, Instruction &CI) { Value *V = Pred == ICmpInst::ICMP_NE ? ConstantInt::getAllOnesValue(CI.getType()) : ConstantInt::getNullValue(CI.getType()); - return ReplaceInstUsesWith(CI, V); + return replaceInstUsesWith(CI, V); } if (!Op1C->isZero() == (Pred == ICmpInst::ICMP_NE)) { @@ -1041,7 +1041,7 @@ Instruction *InstCombiner::transformSExtICmp(ICmpInst *ICI, Instruction &CI) { } if (CI.getType() == In->getType()) - return ReplaceInstUsesWith(CI, In); + return replaceInstUsesWith(CI, In); return CastInst::CreateIntegerCast(In, CI.getType(), true/*SExt*/); } } @@ -1137,7 +1137,7 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) { ComputeSignBit(Src, KnownZero, KnownOne, 0, &CI); if (KnownZero) { Value *ZExt = Builder->CreateZExt(Src, DestTy); - return ReplaceInstUsesWith(CI, ZExt); + return replaceInstUsesWith(CI, ZExt); } // Attempt to extend the entire input expression tree to the destination @@ -1158,7 +1158,7 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) { // If the high bits are already filled with sign bit, just replace this // cast with the result. if (ComputeNumSignBits(Res, 0, &CI) > DestBitSize - SrcBitSize) - return ReplaceInstUsesWith(CI, Res); + return replaceInstUsesWith(CI, Res); // We need to emit a shl + ashr to do the sign extend. 
Value *ShAmt = ConstantInt::get(DestTy, DestBitSize-SrcBitSize); @@ -1400,8 +1400,11 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) { Function *Overload = Intrinsic::getDeclaration( CI.getModule(), II->getIntrinsicID(), IntrinsicType); + SmallVector<OperandBundleDef, 1> OpBundles; + II->getOperandBundlesAsDefs(OpBundles); + Value *Args[] = { InnerTrunc }; - return CallInst::Create(Overload, Args, II->getName()); + return CallInst::Create(Overload, Args, OpBundles, II->getName()); } } } @@ -1451,7 +1454,7 @@ Instruction *InstCombiner::FoldItoFPtoI(Instruction &FI) { if (FITy->getScalarSizeInBits() < SrcTy->getScalarSizeInBits()) return new TruncInst(SrcI, FITy); if (SrcTy == FITy) - return ReplaceInstUsesWith(FI, SrcI); + return replaceInstUsesWith(FI, SrcI); return new BitCastInst(SrcI, FITy); } return nullptr; @@ -1796,7 +1799,7 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { // Get rid of casts from one type to the same type. These are useless and can // be replaced by the operand. if (DestTy == Src->getType()) - return ReplaceInstUsesWith(CI, Src); + return replaceInstUsesWith(CI, Src); if (PointerType *DstPTy = dyn_cast<PointerType>(DestTy)) { PointerType *SrcPTy = cast<PointerType>(SrcTy); @@ -1811,6 +1814,13 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { if (Instruction *V = PromoteCastOfAllocation(CI, *AI)) return V; + // When the type pointed to is not sized the cast cannot be + // turned into a gep. + Type *PointeeType = + cast<PointerType>(Src->getType()->getScalarType())->getElementType(); + if (!PointeeType->isSized()) + return nullptr; + // If the source and destination are pointers, and this cast is equivalent // to a getelementptr X, 0, 0, 0... turn it into the appropriate gep. // This can enhance SROA and other transforms that want type-safe pointers. @@ -1854,7 +1864,7 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { // assemble the elements of the vector manually. Try to rip the code out // and replace it with insertelements. if (Value *V = optimizeIntegerToVectorInsertions(CI, *this)) - return ReplaceInstUsesWith(CI, V); + return replaceInstUsesWith(CI, V); } } diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp index d9311a343eadb..bfd73f4bbac5d 100644 --- a/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -13,18 +13,19 @@ #include "InstCombineInternal.h" #include "llvm/ADT/APSInt.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/PatternMatch.h" -#include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" -#include "llvm/Analysis/TargetLibraryInfo.h" using namespace llvm; using namespace PatternMatch; @@ -55,8 +56,8 @@ static bool HasAddOverflow(ConstantInt *Result, return Result->getValue().slt(In1->getValue()); } -/// AddWithOverflow - Compute Result = In1+In2, returning true if the result -/// overflowed for this type. +/// Compute Result = In1+In2, returning true if the result overflowed for this +/// type. 
static bool AddWithOverflow(Constant *&Result, Constant *In1, Constant *In2, bool IsSigned = false) { Result = ConstantExpr::getAdd(In1, In2); @@ -90,8 +91,8 @@ static bool HasSubOverflow(ConstantInt *Result, return Result->getValue().sgt(In1->getValue()); } -/// SubWithOverflow - Compute Result = In1-In2, returning true if the result -/// overflowed for this type. +/// Compute Result = In1-In2, returning true if the result overflowed for this +/// type. static bool SubWithOverflow(Constant *&Result, Constant *In1, Constant *In2, bool IsSigned = false) { Result = ConstantExpr::getSub(In1, In2); @@ -113,13 +114,21 @@ static bool SubWithOverflow(Constant *&Result, Constant *In1, IsSigned); } -/// isSignBitCheck - Given an exploded icmp instruction, return true if the -/// comparison only checks the sign bit. If it only checks the sign bit, set -/// TrueIfSigned if the result of the comparison is true when the input value is -/// signed. -static bool isSignBitCheck(ICmpInst::Predicate pred, ConstantInt *RHS, +/// Given an icmp instruction, return true if any use of this comparison is a +/// branch on sign bit comparison. +static bool isBranchOnSignBitCheck(ICmpInst &I, bool isSignBit) { + for (auto *U : I.users()) + if (isa<BranchInst>(U)) + return isSignBit; + return false; +} + +/// Given an exploded icmp instruction, return true if the comparison only +/// checks the sign bit. If it only checks the sign bit, set TrueIfSigned if the +/// result of the comparison is true when the input value is signed. +static bool isSignBitCheck(ICmpInst::Predicate Pred, ConstantInt *RHS, bool &TrueIfSigned) { - switch (pred) { + switch (Pred) { case ICmpInst::ICMP_SLT: // True if LHS s< 0 TrueIfSigned = true; return RHS->isZero(); @@ -145,21 +154,21 @@ static bool isSignBitCheck(ICmpInst::Predicate pred, ConstantInt *RHS, /// Returns true if the exploded icmp can be expressed as a signed comparison /// to zero and updates the predicate accordingly. /// The signedness of the comparison is preserved. -static bool isSignTest(ICmpInst::Predicate &pred, const ConstantInt *RHS) { - if (!ICmpInst::isSigned(pred)) +static bool isSignTest(ICmpInst::Predicate &Pred, const ConstantInt *RHS) { + if (!ICmpInst::isSigned(Pred)) return false; if (RHS->isZero()) - return ICmpInst::isRelational(pred); + return ICmpInst::isRelational(Pred); if (RHS->isOne()) { - if (pred == ICmpInst::ICMP_SLT) { - pred = ICmpInst::ICMP_SLE; + if (Pred == ICmpInst::ICMP_SLT) { + Pred = ICmpInst::ICMP_SLE; return true; } } else if (RHS->isAllOnesValue()) { - if (pred == ICmpInst::ICMP_SGT) { - pred = ICmpInst::ICMP_SGE; + if (Pred == ICmpInst::ICMP_SGT) { + Pred = ICmpInst::ICMP_SGE; return true; } } @@ -167,19 +176,18 @@ static bool isSignTest(ICmpInst::Predicate &pred, const ConstantInt *RHS) { return false; } -// isHighOnes - Return true if the constant is of the form 1+0+. -// This is the same as lowones(~X). +/// Return true if the constant is of the form 1+0+. This is the same as +/// lowones(~X). static bool isHighOnes(const ConstantInt *CI) { return (~CI->getValue() + 1).isPowerOf2(); } -/// ComputeSignedMinMaxValuesFromKnownBits - Given a signed integer type and a -/// set of known zero and one bits, compute the maximum and minimum values that -/// could have the specified known zero and known one bits, returning them in -/// min/max. 
-static void ComputeSignedMinMaxValuesFromKnownBits(const APInt& KnownZero, - const APInt& KnownOne, - APInt& Min, APInt& Max) { +/// Given a signed integer type and a set of known zero and one bits, compute +/// the maximum and minimum values that could have the specified known zero and +/// known one bits, returning them in Min/Max. +static void ComputeSignedMinMaxValuesFromKnownBits(const APInt &KnownZero, + const APInt &KnownOne, + APInt &Min, APInt &Max) { assert(KnownZero.getBitWidth() == KnownOne.getBitWidth() && KnownZero.getBitWidth() == Min.getBitWidth() && KnownZero.getBitWidth() == Max.getBitWidth() && @@ -197,10 +205,9 @@ static void ComputeSignedMinMaxValuesFromKnownBits(const APInt& KnownZero, } } -// ComputeUnsignedMinMaxValuesFromKnownBits - Given an unsigned integer type and -// a set of known zero and one bits, compute the maximum and minimum values that -// could have the specified known zero and known one bits, returning them in -// min/max. +/// Given an unsigned integer type and a set of known zero and one bits, compute +/// the maximum and minimum values that could have the specified known zero and +/// known one bits, returning them in Min/Max. static void ComputeUnsignedMinMaxValuesFromKnownBits(const APInt &KnownZero, const APInt &KnownOne, APInt &Min, APInt &Max) { @@ -216,14 +223,14 @@ static void ComputeUnsignedMinMaxValuesFromKnownBits(const APInt &KnownZero, Max = KnownOne|UnknownBits; } -/// FoldCmpLoadFromIndexedGlobal - Called we see this pattern: +/// This is called when we see this pattern: /// cmp pred (load (gep GV, ...)), cmpcst -/// where GV is a global variable with a constant initializer. Try to simplify -/// this into some simple computation that does not need the load. For example +/// where GV is a global variable with a constant initializer. Try to simplify +/// this into some simple computation that does not need the load. For example /// we can optimize "icmp eq (load (gep "foo", 0, i)), 0" into "icmp eq i, 3". /// /// If AndCst is non-null, then the loaded value is masked with that constant -/// before doing the comparison. This handles cases like "A[i]&4 == 0". +/// before doing the comparison. This handles cases like "A[i]&4 == 0". Instruction *InstCombiner:: FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, CmpInst &ICI, ConstantInt *AndCst) { @@ -401,7 +408,7 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, if (SecondTrueElement != Overdefined) { // None true -> false. if (FirstTrueElement == Undefined) - return ReplaceInstUsesWith(ICI, Builder->getFalse()); + return replaceInstUsesWith(ICI, Builder->getFalse()); Value *FirstTrueIdx = ConstantInt::get(Idx->getType(), FirstTrueElement); @@ -421,7 +428,7 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, if (SecondFalseElement != Overdefined) { // None false -> true. if (FirstFalseElement == Undefined) - return ReplaceInstUsesWith(ICI, Builder->getTrue()); + return replaceInstUsesWith(ICI, Builder->getTrue()); Value *FirstFalseIdx = ConstantInt::get(Idx->getType(), FirstFalseElement); @@ -492,12 +499,12 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, return nullptr; } -/// EvaluateGEPOffsetExpression - Return a value that can be used to compare -/// the *offset* implied by a GEP to zero. For example, if we have &A[i], we -/// want to return 'i' for "icmp ne i, 0". Note that, in general, indices can -/// be complex, and scales are involved. 
The above expression would also be -/// legal to codegen as "icmp ne (i*4), 0" (assuming A is a pointer to i32). -/// This later form is less amenable to optimization though, and we are allowed +/// Return a value that can be used to compare the *offset* implied by a GEP to +/// zero. For example, if we have &A[i], we want to return 'i' for +/// "icmp ne i, 0". Note that, in general, indices can be complex, and scales +/// are involved. The above expression would also be legal to codegen as +/// "icmp ne (i*4), 0" (assuming A is a pointer to i32). +/// This latter form is less amenable to optimization though, and we are allowed /// to generate the first by knowing that pointer arithmetic doesn't overflow. /// /// If we can't emit an optimized form for this expression, this returns null. @@ -595,8 +602,323 @@ static Value *EvaluateGEPOffsetExpression(User *GEP, InstCombiner &IC, return IC.Builder->CreateAdd(VariableIdx, OffsetVal, "offset"); } -/// FoldGEPICmp - Fold comparisons between a GEP instruction and something -/// else. At this point we know that the GEP is on the LHS of the comparison. +/// Returns true if we can rewrite Start as a GEP with pointer Base +/// and some integer offset. The nodes that need to be re-written +/// for this transformation will be added to Explored. +static bool canRewriteGEPAsOffset(Value *Start, Value *Base, + const DataLayout &DL, + SetVector<Value *> &Explored) { + SmallVector<Value *, 16> WorkList(1, Start); + Explored.insert(Base); + + // The following traversal gives us an order which can be used + // when doing the final transformation. Since in the final + // transformation we create the PHI replacement instructions first, + // we don't have to get them in any particular order. + // + // However, for other instructions we will have to traverse the + // operands of an instruction first, which means that we have to + // do a post-order traversal. + while (!WorkList.empty()) { + SetVector<PHINode *> PHIs; + + while (!WorkList.empty()) { + if (Explored.size() >= 100) + return false; + + Value *V = WorkList.back(); + + if (Explored.count(V) != 0) { + WorkList.pop_back(); + continue; + } + + if (!isa<IntToPtrInst>(V) && !isa<PtrToIntInst>(V) && + !isa<GEPOperator>(V) && !isa<PHINode>(V)) + // We've found some value that we can't explore which is different from + // the base. Therefore we can't do this transformation. + return false; + + if (isa<IntToPtrInst>(V) || isa<PtrToIntInst>(V)) { + auto *CI = dyn_cast<CastInst>(V); + if (!CI->isNoopCast(DL)) + return false; + + if (Explored.count(CI->getOperand(0)) == 0) + WorkList.push_back(CI->getOperand(0)); + } + + if (auto *GEP = dyn_cast<GEPOperator>(V)) { + // We're limiting the GEP to having one index. This will preserve + // the original pointer type. We could handle more cases in the + // future. + if (GEP->getNumIndices() != 1 || !GEP->isInBounds() || + GEP->getType() != Start->getType()) + return false; + + if (Explored.count(GEP->getOperand(0)) == 0) + WorkList.push_back(GEP->getOperand(0)); + } + + if (WorkList.back() == V) { + WorkList.pop_back(); + // We've finished visiting this node, mark it as such. + Explored.insert(V); + } + + if (auto *PN = dyn_cast<PHINode>(V)) { + // We cannot transform PHIs on unsplittable basic blocks. + if (isa<CatchSwitchInst>(PN->getParent()->getTerminator())) + return false; + Explored.insert(PN); + PHIs.insert(PN); + } + } + + // Explore the PHI nodes further. 
+ for (auto *PN : PHIs) + for (Value *Op : PN->incoming_values()) + if (Explored.count(Op) == 0) + WorkList.push_back(Op); + } + + // Make sure that we can do this. Since we can't insert GEPs in a basic + // block before a PHI node, we can't easily do this transformation if + // we have PHI node users of transformed instructions. + for (Value *Val : Explored) { + for (Value *Use : Val->uses()) { + + auto *PHI = dyn_cast<PHINode>(Use); + auto *Inst = dyn_cast<Instruction>(Val); + + if (Inst == Base || Inst == PHI || !Inst || !PHI || + Explored.count(PHI) == 0) + continue; + + if (PHI->getParent() == Inst->getParent()) + return false; + } + } + return true; +} + +// Sets the appropriate insert point on Builder where we can add +// a replacement Instruction for V (if that is possible). +static void setInsertionPoint(IRBuilder<> &Builder, Value *V, + bool Before = true) { + if (auto *PHI = dyn_cast<PHINode>(V)) { + Builder.SetInsertPoint(&*PHI->getParent()->getFirstInsertionPt()); + return; + } + if (auto *I = dyn_cast<Instruction>(V)) { + if (!Before) + I = &*std::next(I->getIterator()); + Builder.SetInsertPoint(I); + return; + } + if (auto *A = dyn_cast<Argument>(V)) { + // Set the insertion point in the entry block. + BasicBlock &Entry = A->getParent()->getEntryBlock(); + Builder.SetInsertPoint(&*Entry.getFirstInsertionPt()); + return; + } + // Otherwise, this is a constant and we don't need to set a new + // insertion point. + assert(isa<Constant>(V) && "Setting insertion point for unknown value!"); +} + +/// Returns a re-written value of Start as an indexed GEP using Base as a +/// pointer. +static Value *rewriteGEPAsOffset(Value *Start, Value *Base, + const DataLayout &DL, + SetVector<Value *> &Explored) { + // Perform all the substitutions. This is a bit tricky because we can + // have cycles in our use-def chains. + // 1. Create the PHI nodes without any incoming values. + // 2. Create all the other values. + // 3. Add the edges for the PHI nodes. + // 4. Emit GEPs to get the original pointers. + // 5. Remove the original instructions. + Type *IndexType = IntegerType::get( + Base->getContext(), DL.getPointerTypeSizeInBits(Start->getType())); + + DenseMap<Value *, Value *> NewInsts; + NewInsts[Base] = ConstantInt::getNullValue(IndexType); + + // Create the new PHI nodes, without adding any incoming values. + for (Value *Val : Explored) { + if (Val == Base) + continue; + // Create empty phi nodes. This avoids cyclic dependencies when creating + // the remaining instructions. + if (auto *PHI = dyn_cast<PHINode>(Val)) + NewInsts[PHI] = PHINode::Create(IndexType, PHI->getNumIncomingValues(), + PHI->getName() + ".idx", PHI); + } + IRBuilder<> Builder(Base->getContext()); + + // Create all the other instructions. + for (Value *Val : Explored) { + + if (NewInsts.find(Val) != NewInsts.end()) + continue; + + if (auto *CI = dyn_cast<CastInst>(Val)) { + NewInsts[CI] = NewInsts[CI->getOperand(0)]; + continue; + } + if (auto *GEP = dyn_cast<GEPOperator>(Val)) { + Value *Index = NewInsts[GEP->getOperand(1)] ? NewInsts[GEP->getOperand(1)] + : GEP->getOperand(1); + setInsertionPoint(Builder, GEP); + // Indices might need to be sign extended. GEPs will magically do + // this, but we need to do it ourselves here. 
+ if (Index->getType()->getScalarSizeInBits() != + NewInsts[GEP->getOperand(0)]->getType()->getScalarSizeInBits()) { + Index = Builder.CreateSExtOrTrunc( + Index, NewInsts[GEP->getOperand(0)]->getType(), + GEP->getOperand(0)->getName() + ".sext"); + } + + auto *Op = NewInsts[GEP->getOperand(0)]; + if (isa<ConstantInt>(Op) && dyn_cast<ConstantInt>(Op)->isZero()) + NewInsts[GEP] = Index; + else + NewInsts[GEP] = Builder.CreateNSWAdd( + Op, Index, GEP->getOperand(0)->getName() + ".add"); + continue; + } + if (isa<PHINode>(Val)) + continue; + + llvm_unreachable("Unexpected instruction type"); + } + + // Add the incoming values to the PHI nodes. + for (Value *Val : Explored) { + if (Val == Base) + continue; + // All the instructions have been created, we can now add edges to the + // phi nodes. + if (auto *PHI = dyn_cast<PHINode>(Val)) { + PHINode *NewPhi = static_cast<PHINode *>(NewInsts[PHI]); + for (unsigned I = 0, E = PHI->getNumIncomingValues(); I < E; ++I) { + Value *NewIncoming = PHI->getIncomingValue(I); + + if (NewInsts.find(NewIncoming) != NewInsts.end()) + NewIncoming = NewInsts[NewIncoming]; + + NewPhi->addIncoming(NewIncoming, PHI->getIncomingBlock(I)); + } + } + } + + for (Value *Val : Explored) { + if (Val == Base) + continue; + + // Depending on the type, for external users we have to emit + // a GEP or a GEP + ptrtoint. + setInsertionPoint(Builder, Val, false); + + // If required, create an inttoptr instruction for Base. + Value *NewBase = Base; + if (!Base->getType()->isPointerTy()) + NewBase = Builder.CreateBitOrPointerCast(Base, Start->getType(), + Start->getName() + "to.ptr"); + + Value *GEP = Builder.CreateInBoundsGEP( + Start->getType()->getPointerElementType(), NewBase, + makeArrayRef(NewInsts[Val]), Val->getName() + ".ptr"); + + if (!Val->getType()->isPointerTy()) { + Value *Cast = Builder.CreatePointerCast(GEP, Val->getType(), + Val->getName() + ".conv"); + GEP = Cast; + } + Val->replaceAllUsesWith(GEP); + } + + return NewInsts[Start]; +} + +/// Looks through GEPs, IntToPtrInsts and PtrToIntInsts in order to express +/// the input Value as a constant indexed GEP. Returns a pair containing +/// the GEPs Pointer and Index. +static std::pair<Value *, Value *> +getAsConstantIndexedAddress(Value *V, const DataLayout &DL) { + Type *IndexType = IntegerType::get(V->getContext(), + DL.getPointerTypeSizeInBits(V->getType())); + + Constant *Index = ConstantInt::getNullValue(IndexType); + while (true) { + if (GEPOperator *GEP = dyn_cast<GEPOperator>(V)) { + // We accept only inbouds GEPs here to exclude the possibility of + // overflow. + if (!GEP->isInBounds()) + break; + if (GEP->hasAllConstantIndices() && GEP->getNumIndices() == 1 && + GEP->getType() == V->getType()) { + V = GEP->getOperand(0); + Constant *GEPIndex = static_cast<Constant *>(GEP->getOperand(1)); + Index = ConstantExpr::getAdd( + Index, ConstantExpr::getSExtOrBitCast(GEPIndex, IndexType)); + continue; + } + break; + } + if (auto *CI = dyn_cast<IntToPtrInst>(V)) { + if (!CI->isNoopCast(DL)) + break; + V = CI->getOperand(0); + continue; + } + if (auto *CI = dyn_cast<PtrToIntInst>(V)) { + if (!CI->isNoopCast(DL)) + break; + V = CI->getOperand(0); + continue; + } + break; + } + return {V, Index}; +} + +/// Converts (CMP GEPLHS, RHS) if this change would make RHS a constant. +/// We can look through PHIs, GEPs and casts in order to determine a common base +/// between GEPLHS and RHS. 
+static Instruction *transformToIndexedCompare(GEPOperator *GEPLHS, Value *RHS, + ICmpInst::Predicate Cond, + const DataLayout &DL) { + if (!GEPLHS->hasAllConstantIndices()) + return nullptr; + + Value *PtrBase, *Index; + std::tie(PtrBase, Index) = getAsConstantIndexedAddress(GEPLHS, DL); + + // The set of nodes that will take part in this transformation. + SetVector<Value *> Nodes; + + if (!canRewriteGEPAsOffset(RHS, PtrBase, DL, Nodes)) + return nullptr; + + // We know we can re-write this as + // ((gep Ptr, OFFSET1) cmp (gep Ptr, OFFSET2) + // Since we've only looked through inbouds GEPs we know that we + // can't have overflow on either side. We can therefore re-write + // this as: + // OFFSET1 cmp OFFSET2 + Value *NewRHS = rewriteGEPAsOffset(RHS, PtrBase, DL, Nodes); + + // RewriteGEPAsOffset has replaced RHS and all of its uses with a re-written + // GEP having PtrBase as the pointer base, and has returned in NewRHS the + // offset. Since Index is the offset of LHS to the base pointer, we will now + // compare the offsets instead of comparing the pointers. + return new ICmpInst(ICmpInst::getSignedPredicate(Cond), Index, NewRHS); +} + +/// Fold comparisons between a GEP instruction and something else. At this point +/// we know that the GEP is on the LHS of the comparison. Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, ICmpInst::Predicate Cond, Instruction &I) { @@ -670,12 +992,13 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, Value *Cmp = Builder->CreateICmp(ICmpInst::getSignedPredicate(Cond), LOffset, ROffset); - return ReplaceInstUsesWith(I, Cmp); + return replaceInstUsesWith(I, Cmp); } // Otherwise, the base pointers are different and the indices are - // different, bail out. - return nullptr; + // different. Try convert this to an indexed compare by looking through + // PHIs/casts. + return transformToIndexedCompare(GEPLHS, RHS, Cond, DL); } // If one of the GEPs has all zero indices, recurse. @@ -706,7 +1029,7 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, } if (NumDifferences == 0) // SAME GEP? - return ReplaceInstUsesWith(I, // No comparison is needed here. + return replaceInstUsesWith(I, // No comparison is needed here. Builder->getInt1(ICmpInst::isTrueWhenEqual(Cond))); else if (NumDifferences == 1 && GEPsInBounds) { @@ -727,7 +1050,10 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, return new ICmpInst(ICmpInst::getSignedPredicate(Cond), L, R); } } - return nullptr; + + // Try convert this to an indexed compare by looking through PHIs/casts as a + // last resort. + return transformToIndexedCompare(GEPLHS, RHS, Cond, DL); } Instruction *InstCombiner::FoldAllocaCmp(ICmpInst &ICI, AllocaInst *Alloca, @@ -802,12 +1128,12 @@ Instruction *InstCombiner::FoldAllocaCmp(ICmpInst &ICI, AllocaInst *Alloca, } Type *CmpTy = CmpInst::makeCmpResultType(Other->getType()); - return ReplaceInstUsesWith( + return replaceInstUsesWith( ICI, ConstantInt::get(CmpTy, !CmpInst::isTrueWhenEqual(ICI.getPredicate()))); } -/// FoldICmpAddOpCst - Fold "icmp pred (X+CI), X". +/// Fold "icmp pred (X+CI), X". 
Instruction *InstCombiner::FoldICmpAddOpCst(Instruction &ICI, Value *X, ConstantInt *CI, ICmpInst::Predicate Pred) { @@ -855,8 +1181,8 @@ Instruction *InstCombiner::FoldICmpAddOpCst(Instruction &ICI, return new ICmpInst(ICmpInst::ICMP_SLT, X, ConstantExpr::getSub(SMax, C)); } -/// FoldICmpDivCst - Fold "icmp pred, ([su]div X, DivRHS), CmpRHS" where DivRHS -/// and CmpRHS are both known to be integer constants. +/// Fold "icmp pred, ([su]div X, DivRHS), CmpRHS" where DivRHS and CmpRHS are +/// both known to be integer constants. Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, ConstantInt *DivRHS) { ConstantInt *CmpRHS = cast<ConstantInt>(ICI.getOperand(1)); @@ -898,8 +1224,8 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, // Get the ICmp opcode ICmpInst::Predicate Pred = ICI.getPredicate(); - /// If the division is known to be exact, then there is no remainder from the - /// divide, so the covered range size is unit, otherwise it is the divisor. + // If the division is known to be exact, then there is no remainder from the + // divide, so the covered range size is unit, otherwise it is the divisor. ConstantInt *RangeSize = DivI->isExact() ? getOne(Prod) : DivRHS; // Figure out the interval that is being checked. For example, a comparison @@ -973,46 +1299,46 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, default: llvm_unreachable("Unhandled icmp opcode!"); case ICmpInst::ICMP_EQ: if (LoOverflow && HiOverflow) - return ReplaceInstUsesWith(ICI, Builder->getFalse()); + return replaceInstUsesWith(ICI, Builder->getFalse()); if (HiOverflow) return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, X, LoBound); if (LoOverflow) return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, X, HiBound); - return ReplaceInstUsesWith(ICI, InsertRangeTest(X, LoBound, HiBound, + return replaceInstUsesWith(ICI, InsertRangeTest(X, LoBound, HiBound, DivIsSigned, true)); case ICmpInst::ICMP_NE: if (LoOverflow && HiOverflow) - return ReplaceInstUsesWith(ICI, Builder->getTrue()); + return replaceInstUsesWith(ICI, Builder->getTrue()); if (HiOverflow) return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, X, LoBound); if (LoOverflow) return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, X, HiBound); - return ReplaceInstUsesWith(ICI, InsertRangeTest(X, LoBound, HiBound, + return replaceInstUsesWith(ICI, InsertRangeTest(X, LoBound, HiBound, DivIsSigned, false)); case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_SLT: if (LoOverflow == +1) // Low bound is greater than input range. - return ReplaceInstUsesWith(ICI, Builder->getTrue()); + return replaceInstUsesWith(ICI, Builder->getTrue()); if (LoOverflow == -1) // Low bound is less than input range. - return ReplaceInstUsesWith(ICI, Builder->getFalse()); + return replaceInstUsesWith(ICI, Builder->getFalse()); return new ICmpInst(Pred, X, LoBound); case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_SGT: if (HiOverflow == +1) // High bound greater than input range. - return ReplaceInstUsesWith(ICI, Builder->getFalse()); + return replaceInstUsesWith(ICI, Builder->getFalse()); if (HiOverflow == -1) // High bound less than input range. 
- return ReplaceInstUsesWith(ICI, Builder->getTrue()); + return replaceInstUsesWith(ICI, Builder->getTrue()); if (Pred == ICmpInst::ICMP_UGT) return new ICmpInst(ICmpInst::ICMP_UGE, X, HiBound); return new ICmpInst(ICmpInst::ICMP_SGE, X, HiBound); } } -/// FoldICmpShrCst - Handle "icmp(([al]shr X, cst1), cst2)". +/// Handle "icmp(([al]shr X, cst1), cst2)". Instruction *InstCombiner::FoldICmpShrCst(ICmpInst &ICI, BinaryOperator *Shr, ConstantInt *ShAmt) { const APInt &CmpRHSV = cast<ConstantInt>(ICI.getOperand(1))->getValue(); @@ -1077,7 +1403,7 @@ Instruction *InstCombiner::FoldICmpShrCst(ICmpInst &ICI, BinaryOperator *Shr, if (Comp != CmpRHSV) { // Comparing against a bit that we know is zero. bool IsICMP_NE = ICI.getPredicate() == ICmpInst::ICMP_NE; Constant *Cst = Builder->getInt1(IsICMP_NE); - return ReplaceInstUsesWith(ICI, Cst); + return replaceInstUsesWith(ICI, Cst); } // Otherwise, check to see if the bits shifted out are known to be zero. @@ -1098,7 +1424,7 @@ Instruction *InstCombiner::FoldICmpShrCst(ICmpInst &ICI, BinaryOperator *Shr, return nullptr; } -/// FoldICmpCstShrCst - Handle "(icmp eq/ne (ashr/lshr const2, A), const1)" -> +/// Handle "(icmp eq/ne (ashr/lshr const2, A), const1)" -> /// (icmp eq/ne A, Log2(const2/const1)) -> /// (icmp eq/ne A, Log2(const2) - Log2(const1)). Instruction *InstCombiner::FoldICmpCstShrCst(ICmpInst &I, Value *Op, Value *A, @@ -1109,7 +1435,7 @@ Instruction *InstCombiner::FoldICmpCstShrCst(ICmpInst &I, Value *Op, Value *A, auto getConstant = [&I, this](bool IsTrue) { if (I.getPredicate() == I.ICMP_NE) IsTrue = !IsTrue; - return ReplaceInstUsesWith(I, ConstantInt::get(I.getType(), IsTrue)); + return replaceInstUsesWith(I, ConstantInt::get(I.getType(), IsTrue)); }; auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) { @@ -1118,8 +1444,8 @@ Instruction *InstCombiner::FoldICmpCstShrCst(ICmpInst &I, Value *Op, Value *A, return new ICmpInst(Pred, LHS, RHS); }; - APInt AP1 = CI1->getValue(); - APInt AP2 = CI2->getValue(); + const APInt &AP1 = CI1->getValue(); + const APInt &AP2 = CI2->getValue(); // Don't bother doing any work for cases which InstSimplify handles. if (AP2 == 0) @@ -1163,7 +1489,7 @@ Instruction *InstCombiner::FoldICmpCstShrCst(ICmpInst &I, Value *Op, Value *A, return getConstant(false); } -/// FoldICmpCstShlCst - Handle "(icmp eq/ne (shl const2, A), const1)" -> +/// Handle "(icmp eq/ne (shl const2, A), const1)" -> /// (icmp eq/ne A, TrailingZeros(const1) - TrailingZeros(const2)). Instruction *InstCombiner::FoldICmpCstShlCst(ICmpInst &I, Value *Op, Value *A, ConstantInt *CI1, @@ -1173,7 +1499,7 @@ Instruction *InstCombiner::FoldICmpCstShlCst(ICmpInst &I, Value *Op, Value *A, auto getConstant = [&I, this](bool IsTrue) { if (I.getPredicate() == I.ICMP_NE) IsTrue = !IsTrue; - return ReplaceInstUsesWith(I, ConstantInt::get(I.getType(), IsTrue)); + return replaceInstUsesWith(I, ConstantInt::get(I.getType(), IsTrue)); }; auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) { @@ -1182,8 +1508,8 @@ Instruction *InstCombiner::FoldICmpCstShlCst(ICmpInst &I, Value *Op, Value *A, return new ICmpInst(Pred, LHS, RHS); }; - APInt AP1 = CI1->getValue(); - APInt AP2 = CI2->getValue(); + const APInt &AP1 = CI1->getValue(); + const APInt &AP2 = CI2->getValue(); // Don't bother doing any work for cases which InstSimplify handles. 
if (AP2 == 0) @@ -1208,8 +1534,7 @@ Instruction *InstCombiner::FoldICmpCstShlCst(ICmpInst &I, Value *Op, Value *A, return getConstant(false); } -/// visitICmpInstWithInstAndIntCst - Handle "icmp (instr, intcst)". -/// +/// Handle "icmp (instr, intcst)". Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, Instruction *LHSI, ConstantInt *RHS) { @@ -1412,9 +1737,9 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, // As a special case, check to see if this means that the // result is always true or false now. if (ICI.getPredicate() == ICmpInst::ICMP_EQ) - return ReplaceInstUsesWith(ICI, Builder->getFalse()); + return replaceInstUsesWith(ICI, Builder->getFalse()); if (ICI.getPredicate() == ICmpInst::ICMP_NE) - return ReplaceInstUsesWith(ICI, Builder->getTrue()); + return replaceInstUsesWith(ICI, Builder->getTrue()); } else { ICI.setOperand(1, NewCst); Constant *NewAndCst; @@ -1674,7 +1999,7 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, if (Comp != RHS) {// Comparing against a bit that we know is zero. bool IsICMP_NE = ICI.getPredicate() == ICmpInst::ICMP_NE; Constant *Cst = Builder->getInt1(IsICMP_NE); - return ReplaceInstUsesWith(ICI, Cst); + return replaceInstUsesWith(ICI, Cst); } // If the shift is NUW, then it is just shifting out zeros, no need for an @@ -1764,8 +2089,28 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, break; } - case Instruction::SDiv: case Instruction::UDiv: + if (ConstantInt *DivLHS = dyn_cast<ConstantInt>(LHSI->getOperand(0))) { + Value *X = LHSI->getOperand(1); + const APInt &C1 = RHS->getValue(); + const APInt &C2 = DivLHS->getValue(); + assert(C2 != 0 && "udiv 0, X should have been simplified already."); + // (icmp ugt (udiv C2, X), C1) -> (icmp ule X, C2/(C1+1)) + if (ICI.getPredicate() == ICmpInst::ICMP_UGT) { + assert(!C1.isMaxValue() && + "icmp ugt X, UINT_MAX should have been simplified already."); + return new ICmpInst(ICmpInst::ICMP_ULE, X, + ConstantInt::get(X->getType(), C2.udiv(C1 + 1))); + } + // (icmp ult (udiv C2, X), C1) -> (icmp ugt X, C2/C1) + if (ICI.getPredicate() == ICmpInst::ICMP_ULT) { + assert(C1 != 0 && "icmp ult X, 0 should have been simplified already."); + return new ICmpInst(ICmpInst::ICMP_UGT, X, + ConstantInt::get(X->getType(), C2.udiv(C1))); + } + } + // fall-through + case Instruction::SDiv: // Fold: icmp pred ([us]div X, C1), C2 -> range test // Fold this div into the comparison, producing a range check. // Determine, based on the divide type, what the range is being @@ -1895,27 +2240,30 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, } break; case Instruction::Xor: - // For the xor case, we can xor two constants together, eliminating - // the explicit xor. - if (Constant *BOC = dyn_cast<Constant>(BO->getOperand(1))) { - return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), - ConstantExpr::getXor(RHS, BOC)); - } else if (RHSV == 0) { - // Replace ((xor A, B) != 0) with (A != B) - return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), - BO->getOperand(1)); + if (BO->hasOneUse()) { + if (Constant *BOC = dyn_cast<Constant>(BO->getOperand(1))) { + // For the xor case, we can xor two constants together, eliminating + // the explicit xor. 
+ return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), + ConstantExpr::getXor(RHS, BOC)); + } else if (RHSV == 0) { + // Replace ((xor A, B) != 0) with (A != B) + return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), + BO->getOperand(1)); + } } break; case Instruction::Sub: - // Replace ((sub A, B) != C) with (B != A-C) if A & C are constants. - if (ConstantInt *BOp0C = dyn_cast<ConstantInt>(BO->getOperand(0))) { - if (BO->hasOneUse()) + if (BO->hasOneUse()) { + if (ConstantInt *BOp0C = dyn_cast<ConstantInt>(BO->getOperand(0))) { + // Replace ((sub A, B) != C) with (B != A-C) if A & C are constants. return new ICmpInst(ICI.getPredicate(), BO->getOperand(1), - ConstantExpr::getSub(BOp0C, RHS)); - } else if (RHSV == 0) { - // Replace ((sub A, B) != 0) with (A != B) - return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), - BO->getOperand(1)); + ConstantExpr::getSub(BOp0C, RHS)); + } else if (RHSV == 0) { + // Replace ((sub A, B) != 0) with (A != B) + return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), + BO->getOperand(1)); + } } break; case Instruction::Or: @@ -1924,7 +2272,16 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, if (ConstantInt *BOC = dyn_cast<ConstantInt>(BO->getOperand(1))) { Constant *NotCI = ConstantExpr::getNot(RHS); if (!ConstantExpr::getAnd(BOC, NotCI)->isNullValue()) - return ReplaceInstUsesWith(ICI, Builder->getInt1(isICMP_NE)); + return replaceInstUsesWith(ICI, Builder->getInt1(isICMP_NE)); + + // Comparing if all bits outside of a constant mask are set? + // Replace (X | C) == -1 with (X & ~C) == ~C. + // This removes the -1 constant. + if (BO->hasOneUse() && RHS->isAllOnesValue()) { + Constant *NotBOC = ConstantExpr::getNot(BOC); + Value *And = Builder->CreateAnd(BO->getOperand(0), NotBOC); + return new ICmpInst(ICI.getPredicate(), And, NotBOC); + } } break; @@ -1933,7 +2290,7 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, // If bits are being compared against that are and'd out, then the // comparison can never succeed! if ((RHSV & ~BOC->getValue()) != 0) - return ReplaceInstUsesWith(ICI, Builder->getInt1(isICMP_NE)); + return replaceInstUsesWith(ICI, Builder->getInt1(isICMP_NE)); // If we have ((X & C) == C), turn it into ((X & C) != 0). if (RHS == BOC && RHSV.isPowerOf2()) @@ -2013,11 +2370,10 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, return nullptr; } -/// visitICmpInstWithCastAndCast - Handle icmp (cast x to y), (cast/cst). -/// We only handle extending casts so far. -/// -Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICI) { - const CastInst *LHSCI = cast<CastInst>(ICI.getOperand(0)); +/// Handle icmp (cast x to y), (cast/cst). We only handle extending casts so +/// far. 
+Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICmp) { + const CastInst *LHSCI = cast<CastInst>(ICmp.getOperand(0)); Value *LHSCIOp = LHSCI->getOperand(0); Type *SrcTy = LHSCIOp->getType(); Type *DestTy = LHSCI->getType(); @@ -2028,7 +2384,7 @@ Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICI) { if (LHSCI->getOpcode() == Instruction::PtrToInt && DL.getPointerTypeSizeInBits(SrcTy) == DestTy->getIntegerBitWidth()) { Value *RHSOp = nullptr; - if (PtrToIntOperator *RHSC = dyn_cast<PtrToIntOperator>(ICI.getOperand(1))) { + if (auto *RHSC = dyn_cast<PtrToIntOperator>(ICmp.getOperand(1))) { Value *RHSCIOp = RHSC->getOperand(0); if (RHSCIOp->getType()->getPointerAddressSpace() == LHSCIOp->getType()->getPointerAddressSpace()) { @@ -2037,11 +2393,12 @@ Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICI) { if (LHSCIOp->getType() != RHSOp->getType()) RHSOp = Builder->CreateBitCast(RHSOp, LHSCIOp->getType()); } - } else if (Constant *RHSC = dyn_cast<Constant>(ICI.getOperand(1))) + } else if (auto *RHSC = dyn_cast<Constant>(ICmp.getOperand(1))) { RHSOp = ConstantExpr::getIntToPtr(RHSC, SrcTy); + } if (RHSOp) - return new ICmpInst(ICI.getPredicate(), LHSCIOp, RHSOp); + return new ICmpInst(ICmp.getPredicate(), LHSCIOp, RHSOp); } // The code below only handles extension cast instructions, so far. @@ -2051,9 +2408,9 @@ Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICI) { return nullptr; bool isSignedExt = LHSCI->getOpcode() == Instruction::SExt; - bool isSignedCmp = ICI.isSigned(); + bool isSignedCmp = ICmp.isSigned(); - if (CastInst *CI = dyn_cast<CastInst>(ICI.getOperand(1))) { + if (auto *CI = dyn_cast<CastInst>(ICmp.getOperand(1))) { // Not an extension from the same type? RHSCIOp = CI->getOperand(0); if (RHSCIOp->getType() != LHSCIOp->getType()) @@ -2065,50 +2422,51 @@ Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICI) { return nullptr; // Deal with equality cases early. - if (ICI.isEquality()) - return new ICmpInst(ICI.getPredicate(), LHSCIOp, RHSCIOp); + if (ICmp.isEquality()) + return new ICmpInst(ICmp.getPredicate(), LHSCIOp, RHSCIOp); // A signed comparison of sign extended values simplifies into a // signed comparison. if (isSignedCmp && isSignedExt) - return new ICmpInst(ICI.getPredicate(), LHSCIOp, RHSCIOp); + return new ICmpInst(ICmp.getPredicate(), LHSCIOp, RHSCIOp); // The other three cases all fold into an unsigned comparison. - return new ICmpInst(ICI.getUnsignedPredicate(), LHSCIOp, RHSCIOp); + return new ICmpInst(ICmp.getUnsignedPredicate(), LHSCIOp, RHSCIOp); } - // If we aren't dealing with a constant on the RHS, exit early - ConstantInt *CI = dyn_cast<ConstantInt>(ICI.getOperand(1)); - if (!CI) + // If we aren't dealing with a constant on the RHS, exit early. + auto *C = dyn_cast<Constant>(ICmp.getOperand(1)); + if (!C) return nullptr; // Compute the constant that would happen if we truncated to SrcTy then - // reextended to DestTy. - Constant *Res1 = ConstantExpr::getTrunc(CI, SrcTy); - Constant *Res2 = ConstantExpr::getCast(LHSCI->getOpcode(), - Res1, DestTy); + // re-extended to DestTy. + Constant *Res1 = ConstantExpr::getTrunc(C, SrcTy); + Constant *Res2 = ConstantExpr::getCast(LHSCI->getOpcode(), Res1, DestTy); // If the re-extended constant didn't change... - if (Res2 == CI) { + if (Res2 == C) { // Deal with equality cases early. 
- if (ICI.isEquality()) - return new ICmpInst(ICI.getPredicate(), LHSCIOp, Res1); + if (ICmp.isEquality()) + return new ICmpInst(ICmp.getPredicate(), LHSCIOp, Res1); // A signed comparison of sign extended values simplifies into a // signed comparison. if (isSignedExt && isSignedCmp) - return new ICmpInst(ICI.getPredicate(), LHSCIOp, Res1); + return new ICmpInst(ICmp.getPredicate(), LHSCIOp, Res1); // The other three cases all fold into an unsigned comparison. - return new ICmpInst(ICI.getUnsignedPredicate(), LHSCIOp, Res1); + return new ICmpInst(ICmp.getUnsignedPredicate(), LHSCIOp, Res1); } - // The re-extended constant changed so the constant cannot be represented - // in the shorter type. Consequently, we cannot emit a simple comparison. + // The re-extended constant changed, partly changed (in the case of a vector), + // or could not be determined to be equal (in the case of a constant + // expression), so the constant cannot be represented in the shorter type. + // Consequently, we cannot emit a simple comparison. // All the cases that fold to true or false will have already been handled // by SimplifyICmpInst, so only deal with the tricky case. - if (isSignedCmp || !isSignedExt) + if (isSignedCmp || !isSignedExt || !isa<ConstantInt>(C)) return nullptr; // Evaluate the comparison for LT (we invert for GT below). LE and GE cases @@ -2117,17 +2475,17 @@ Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICI) { // We're performing an unsigned comp with a sign extended value. // This is true if the input is >= 0. [aka >s -1] Constant *NegOne = Constant::getAllOnesValue(SrcTy); - Value *Result = Builder->CreateICmpSGT(LHSCIOp, NegOne, ICI.getName()); + Value *Result = Builder->CreateICmpSGT(LHSCIOp, NegOne, ICmp.getName()); // Finally, return the value computed. - if (ICI.getPredicate() == ICmpInst::ICMP_ULT) - return ReplaceInstUsesWith(ICI, Result); + if (ICmp.getPredicate() == ICmpInst::ICMP_ULT) + return replaceInstUsesWith(ICmp, Result); - assert(ICI.getPredicate() == ICmpInst::ICMP_UGT && "ICmp should be folded!"); + assert(ICmp.getPredicate() == ICmpInst::ICMP_UGT && "ICmp should be folded!"); return BinaryOperator::CreateNot(Result); } -/// ProcessUGT_ADDCST_ADD - The caller has matched a pattern of the form: +/// The caller has matched a pattern of the form: /// I = icmp ugt (add (add A, B), CI2), CI1 /// If this is of the form: /// sum = a + b @@ -2207,7 +2565,7 @@ static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, // The inner add was the result of the narrow add, zero extended to the // wider type. Replace it with the result computed by the intrinsic. - IC.ReplaceInstUsesWith(*OrigAdd, ZExt); + IC.replaceInstUsesWith(*OrigAdd, ZExt); // The original icmp gets replaced with the overflow value. 
return ExtractValueInst::Create(Call, 1, "sadd.overflow"); @@ -2491,7 +2849,7 @@ static Instruction *ProcessUMulZExtIdiom(ICmpInst &I, Value *MulVal, continue; if (TruncInst *TI = dyn_cast<TruncInst>(U)) { if (TI->getType()->getPrimitiveSizeInBits() == MulWidth) - IC.ReplaceInstUsesWith(*TI, Mul); + IC.replaceInstUsesWith(*TI, Mul); else TI->setOperand(0, Mul); } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) { @@ -2503,7 +2861,7 @@ static Instruction *ProcessUMulZExtIdiom(ICmpInst &I, Value *MulVal, Instruction *Zext = cast<Instruction>(Builder->CreateZExt(ShortAnd, BO->getType())); IC.Worklist.Add(Zext); - IC.ReplaceInstUsesWith(*BO, Zext); + IC.replaceInstUsesWith(*BO, Zext); } else { llvm_unreachable("Unexpected Binary operation"); } @@ -2545,9 +2903,9 @@ static Instruction *ProcessUMulZExtIdiom(ICmpInst &I, Value *MulVal, return ExtractValueInst::Create(Call, 1); } -// DemandedBitsLHSMask - When performing a comparison against a constant, -// it is possible that not all the bits in the LHS are demanded. This helper -// method computes the mask that IS demanded. +/// When performing a comparison against a constant, it is possible that not all +/// the bits in the LHS are demanded. This helper method computes the mask that +/// IS demanded. static APInt DemandedBitsLHSMask(ICmpInst &I, unsigned BitWidth, bool isSignCheck) { if (isSignCheck) @@ -2656,9 +3014,7 @@ bool InstCombiner::dominatesAllUses(const Instruction *DI, return true; } -/// -/// true when the instruction sequence within a block is select-cmp-br. -/// +/// Return true when the instruction sequence within a block is select-cmp-br. static bool isChainSelectCmpBranch(const SelectInst *SI) { const BasicBlock *BB = SI->getParent(); if (!BB) @@ -2672,7 +3028,6 @@ static bool isChainSelectCmpBranch(const SelectInst *SI) { return true; } -/// /// \brief True when a select result is replaced by one of its operands /// in select-icmp sequence. This will eventually result in the elimination /// of the select. @@ -2738,6 +3093,63 @@ bool InstCombiner::replacedSelectWithOperand(SelectInst *SI, return false; } +/// If we have an icmp le or icmp ge instruction with a constant operand, turn +/// it into the appropriate icmp lt or icmp gt instruction. This transform +/// allows them to be folded in visitICmpInst. +static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) { + ICmpInst::Predicate Pred = I.getPredicate(); + if (Pred != ICmpInst::ICMP_SLE && Pred != ICmpInst::ICMP_SGE && + Pred != ICmpInst::ICMP_ULE && Pred != ICmpInst::ICMP_UGE) + return nullptr; + + Value *Op0 = I.getOperand(0); + Value *Op1 = I.getOperand(1); + auto *Op1C = dyn_cast<Constant>(Op1); + if (!Op1C) + return nullptr; + + // Check if the constant operand can be safely incremented/decremented without + // overflowing/underflowing. For scalars, SimplifyICmpInst has already handled + // the edge cases for us, so we just assert on them. For vectors, we must + // handle the edge cases. + Type *Op1Type = Op1->getType(); + bool IsSigned = I.isSigned(); + bool IsLE = (Pred == ICmpInst::ICMP_SLE || Pred == ICmpInst::ICMP_ULE); + auto *CI = dyn_cast<ConstantInt>(Op1C); + if (CI) { + // A <= MAX -> TRUE ; A >= MIN -> TRUE + assert(IsLE ? !CI->isMaxValue(IsSigned) : !CI->isMinValue(IsSigned)); + } else if (Op1Type->isVectorTy()) { + // TODO? If the edge cases for vectors were guaranteed to be handled as they + // are for scalar, we could remove the min/max checks. However, to do that, + // we would have to use insertelement/shufflevector to replace edge values. 
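[Editor's aside, not part of the patch] The canonicalizeCmpWithConstant rewrite above relies on the identity that, as long as the constant is not the maximum value of its type (for <=) or the minimum (for >=), "x <= C" is equivalent to "x < C + 1" and "x >= C" to "x > C - 1", in both the signed and unsigned domains. A small standalone check with an arbitrary constant:

    // Editor's illustration of the +/-1 canonicalization; the constant 7 is
    // arbitrary, any value other than the type's max/min works.
    #include <cassert>
    #include <cstdint>

    int main() {
      const int64_t C = 7;
      for (int64_t x = -20; x <= 20; ++x) {
        assert((x <= C) == (x < C + 1)); // sle -> slt
        assert((x >= C) == (x > C - 1)); // sge -> sgt
      }
      const uint64_t UC = 7;
      for (uint64_t x = 0; x <= 40; ++x) {
        assert((x <= UC) == (x < UC + 1)); // ule -> ult
        assert((x >= UC) == (x > UC - 1)); // uge -> ugt
      }
      return 0;
    }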
+ unsigned NumElts = Op1Type->getVectorNumElements(); + for (unsigned i = 0; i != NumElts; ++i) { + Constant *Elt = Op1C->getAggregateElement(i); + if (!Elt) + return nullptr; + + if (isa<UndefValue>(Elt)) + continue; + // Bail out if we can't determine if this constant is min/max or if we + // know that this constant is min/max. + auto *CI = dyn_cast<ConstantInt>(Elt); + if (!CI || (IsLE ? CI->isMaxValue(IsSigned) : CI->isMinValue(IsSigned))) + return nullptr; + } + } else { + // ConstantExpr? + return nullptr; + } + + // Increment or decrement the constant and set the new comparison predicate: + // ULE -> ULT ; UGE -> UGT ; SLE -> SLT ; SGE -> SGT + Constant *OneOrNegOne = ConstantInt::get(Op1Type, IsLE ? 1 : -1, true); + CmpInst::Predicate NewPred = IsLE ? ICmpInst::ICMP_ULT: ICmpInst::ICMP_UGT; + NewPred = IsSigned ? ICmpInst::getSignedPredicate(NewPred) : NewPred; + return new ICmpInst(NewPred, Op0, ConstantExpr::getAdd(Op1C, OneOrNegOne)); +} + Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { bool Changed = false; Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); @@ -2748,8 +3160,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { /// complex to least complex. This puts constants before unary operators, /// before binary operators. if (Op0Cplxity < Op1Cplxity || - (Op0Cplxity == Op1Cplxity && - swapMayExposeCSEOpportunities(Op0, Op1))) { + (Op0Cplxity == Op1Cplxity && swapMayExposeCSEOpportunities(Op0, Op1))) { I.swapOperands(); std::swap(Op0, Op1); Changed = true; @@ -2757,12 +3168,11 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { if (Value *V = SimplifyICmpInst(I.getPredicate(), Op0, Op1, DL, TLI, DT, AC, &I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // comparing -val or val with non-zero is the same as just comparing val // ie, abs(val) != 0 -> val != 0 - if (I.getPredicate() == ICmpInst::ICMP_NE && match(Op1, m_Zero())) - { + if (I.getPredicate() == ICmpInst::ICMP_NE && match(Op1, m_Zero())) { Value *Cond, *SelectTrue, *SelectFalse; if (match(Op0, m_Select(m_Value(Cond), m_Value(SelectTrue), m_Value(SelectFalse)))) { @@ -2780,47 +3190,50 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { Type *Ty = Op0->getType(); // icmp's with boolean values can always be turned into bitwise operations - if (Ty->isIntegerTy(1)) { + if (Ty->getScalarType()->isIntegerTy(1)) { switch (I.getPredicate()) { default: llvm_unreachable("Invalid icmp instruction!"); - case ICmpInst::ICMP_EQ: { // icmp eq i1 A, B -> ~(A^B) - Value *Xor = Builder->CreateXor(Op0, Op1, I.getName()+"tmp"); + case ICmpInst::ICMP_EQ: { // icmp eq i1 A, B -> ~(A^B) + Value *Xor = Builder->CreateXor(Op0, Op1, I.getName() + "tmp"); return BinaryOperator::CreateNot(Xor); } - case ICmpInst::ICMP_NE: // icmp eq i1 A, B -> A^B + case ICmpInst::ICMP_NE: // icmp ne i1 A, B -> A^B return BinaryOperator::CreateXor(Op0, Op1); case ICmpInst::ICMP_UGT: std::swap(Op0, Op1); // Change icmp ugt -> icmp ult // FALL THROUGH - case ICmpInst::ICMP_ULT:{ // icmp ult i1 A, B -> ~A & B - Value *Not = Builder->CreateNot(Op0, I.getName()+"tmp"); + case ICmpInst::ICMP_ULT:{ // icmp ult i1 A, B -> ~A & B + Value *Not = Builder->CreateNot(Op0, I.getName() + "tmp"); return BinaryOperator::CreateAnd(Not, Op1); } case ICmpInst::ICMP_SGT: std::swap(Op0, Op1); // Change icmp sgt -> icmp slt // FALL THROUGH case ICmpInst::ICMP_SLT: { // icmp slt i1 A, B -> A & ~B - Value *Not = Builder->CreateNot(Op1, I.getName()+"tmp"); + Value *Not = Builder->CreateNot(Op1, I.getName() + "tmp"); return 
BinaryOperator::CreateAnd(Not, Op0); } case ICmpInst::ICMP_UGE: std::swap(Op0, Op1); // Change icmp uge -> icmp ule // FALL THROUGH - case ICmpInst::ICMP_ULE: { // icmp ule i1 A, B -> ~A | B - Value *Not = Builder->CreateNot(Op0, I.getName()+"tmp"); + case ICmpInst::ICMP_ULE: { // icmp ule i1 A, B -> ~A | B + Value *Not = Builder->CreateNot(Op0, I.getName() + "tmp"); return BinaryOperator::CreateOr(Not, Op1); } case ICmpInst::ICMP_SGE: std::swap(Op0, Op1); // Change icmp sge -> icmp sle // FALL THROUGH - case ICmpInst::ICMP_SLE: { // icmp sle i1 A, B -> A | ~B - Value *Not = Builder->CreateNot(Op1, I.getName()+"tmp"); + case ICmpInst::ICMP_SLE: { // icmp sle i1 A, B -> A | ~B + Value *Not = Builder->CreateNot(Op1, I.getName() + "tmp"); return BinaryOperator::CreateOr(Not, Op0); } } } + if (ICmpInst *NewICmp = canonicalizeCmpWithConstant(I)) + return NewICmp; + unsigned BitWidth = 0; if (Ty->isIntOrIntVectorTy()) BitWidth = Ty->getScalarSizeInBits(); @@ -2853,6 +3266,19 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { return Res; } + // (icmp sgt smin(PosA, B) 0) -> (icmp sgt B 0) + if (CI->isZero() && I.getPredicate() == ICmpInst::ICMP_SGT) + if (auto *SI = dyn_cast<SelectInst>(Op0)) { + SelectPatternResult SPR = matchSelectPattern(SI, A, B); + if (SPR.Flavor == SPF_SMIN) { + if (isKnownPositive(A, DL)) + return new ICmpInst(I.getPredicate(), B, CI); + if (isKnownPositive(B, DL)) + return new ICmpInst(I.getPredicate(), A, CI); + } + } + + // The following transforms are only 'worth it' if the only user of the // subtraction is the icmp. if (Op0->hasOneUse()) { @@ -2882,30 +3308,6 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { return new ICmpInst(ICmpInst::ICMP_SLE, A, B); } - // If we have an icmp le or icmp ge instruction, turn it into the - // appropriate icmp lt or icmp gt instruction. This allows us to rely on - // them being folded in the code below. The SimplifyICmpInst code has - // already handled the edge cases for us, so we just assert on them. - switch (I.getPredicate()) { - default: break; - case ICmpInst::ICMP_ULE: - assert(!CI->isMaxValue(false)); // A <=u MAX -> TRUE - return new ICmpInst(ICmpInst::ICMP_ULT, Op0, - Builder->getInt(CI->getValue()+1)); - case ICmpInst::ICMP_SLE: - assert(!CI->isMaxValue(true)); // A <=s MAX -> TRUE - return new ICmpInst(ICmpInst::ICMP_SLT, Op0, - Builder->getInt(CI->getValue()+1)); - case ICmpInst::ICMP_UGE: - assert(!CI->isMinValue(false)); // A >=u MIN -> TRUE - return new ICmpInst(ICmpInst::ICMP_UGT, Op0, - Builder->getInt(CI->getValue()-1)); - case ICmpInst::ICMP_SGE: - assert(!CI->isMinValue(true)); // A >=s MIN -> TRUE - return new ICmpInst(ICmpInst::ICMP_SGT, Op0, - Builder->getInt(CI->getValue()-1)); - } - if (I.isEquality()) { ConstantInt *CI2; if (match(Op0, m_AShr(m_ConstantInt(CI2), m_Value(A))) || @@ -2925,6 +3327,42 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { // bits, if it is a sign bit comparison, it only demands the sign bit. bool UnusedBit; isSignBit = isSignBitCheck(I.getPredicate(), CI, UnusedBit); + + // Canonicalize icmp instructions based on dominating conditions. + BasicBlock *Parent = I.getParent(); + BasicBlock *Dom = Parent->getSinglePredecessor(); + auto *BI = Dom ? 
dyn_cast<BranchInst>(Dom->getTerminator()) : nullptr; + ICmpInst::Predicate Pred; + BasicBlock *TrueBB, *FalseBB; + ConstantInt *CI2; + if (BI && match(BI, m_Br(m_ICmp(Pred, m_Specific(Op0), m_ConstantInt(CI2)), + TrueBB, FalseBB)) && + TrueBB != FalseBB) { + ConstantRange CR = ConstantRange::makeAllowedICmpRegion(I.getPredicate(), + CI->getValue()); + ConstantRange DominatingCR = + (Parent == TrueBB) + ? ConstantRange::makeExactICmpRegion(Pred, CI2->getValue()) + : ConstantRange::makeExactICmpRegion( + CmpInst::getInversePredicate(Pred), CI2->getValue()); + ConstantRange Intersection = DominatingCR.intersectWith(CR); + ConstantRange Difference = DominatingCR.difference(CR); + if (Intersection.isEmptySet()) + return replaceInstUsesWith(I, Builder->getFalse()); + if (Difference.isEmptySet()) + return replaceInstUsesWith(I, Builder->getTrue()); + // Canonicalizing a sign bit comparison that gets used in a branch, + // pessimizes codegen by generating branch on zero instruction instead + // of a test and branch. So we avoid canonicalizing in such situations + // because test and branch instruction has better branch displacement + // than compare and branch instruction. + if (!isBranchOnSignBitCheck(I, isSignBit) && !I.isEquality()) { + if (auto *AI = Intersection.getSingleElement()) + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Builder->getInt(*AI)); + if (auto *AD = Difference.getSingleElement()) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, Builder->getInt(*AD)); + } + } } // See if we can fold the comparison based on range information we can get @@ -2975,7 +3413,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { default: llvm_unreachable("Unknown icmp opcode!"); case ICmpInst::ICMP_EQ: { if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max)) - return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); // If all bits are known zero except for one, then we know at most one // bit is set. If the comparison is against zero, then this is a check @@ -3019,7 +3457,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { } case ICmpInst::ICMP_NE: { if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max)) - return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); // If all bits are known zero except for one, then we know at most one // bit is set. 
If the comparison is against zero, then this is a check @@ -3063,9 +3501,9 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { } case ICmpInst::ICMP_ULT: if (Op0Max.ult(Op1Min)) // A <u B -> true if max(A) < min(B) - return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if (Op0Min.uge(Op1Max)) // A <u B -> false if min(A) >= max(B) - return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); if (Op1Min == Op0Max) // A <u B -> A != B if max(A) == min(B) return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { @@ -3081,9 +3519,9 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { break; case ICmpInst::ICMP_UGT: if (Op0Min.ugt(Op1Max)) // A >u B -> true if min(A) > max(B) - return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if (Op0Max.ule(Op1Min)) // A >u B -> false if max(A) <= max(B) - return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B) return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); @@ -3100,9 +3538,9 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { break; case ICmpInst::ICMP_SLT: if (Op0Max.slt(Op1Min)) // A <s B -> true if max(A) < min(C) - return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if (Op0Min.sge(Op1Max)) // A <s B -> false if min(A) >= max(C) - return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B) return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { @@ -3113,9 +3551,9 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { break; case ICmpInst::ICMP_SGT: if (Op0Min.sgt(Op1Max)) // A >s B -> true if min(A) > max(B) - return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if (Op0Max.sle(Op1Min)) // A >s B -> false if max(A) <= min(B) - return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B) return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); @@ -3128,30 +3566,30 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { case ICmpInst::ICMP_SGE: assert(!isa<ConstantInt>(Op1) && "ICMP_SGE with ConstantInt not folded!"); if (Op0Min.sge(Op1Max)) // A >=s B -> true if min(A) >= max(B) - return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if (Op0Max.slt(Op1Min)) // A >=s B -> false if max(A) < min(B) - return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); break; case ICmpInst::ICMP_SLE: assert(!isa<ConstantInt>(Op1) && "ICMP_SLE with ConstantInt not folded!"); if (Op0Max.sle(Op1Min)) // A <=s B -> true if max(A) <= min(B) - return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if (Op0Min.sgt(Op1Max)) // A <=s 
B -> false if min(A) > max(B) - return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); break; case ICmpInst::ICMP_UGE: assert(!isa<ConstantInt>(Op1) && "ICMP_UGE with ConstantInt not folded!"); if (Op0Min.uge(Op1Max)) // A >=u B -> true if min(A) >= max(B) - return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if (Op0Max.ult(Op1Min)) // A >=u B -> false if max(A) < min(B) - return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); break; case ICmpInst::ICMP_ULE: assert(!isa<ConstantInt>(Op1) && "ICMP_ULE with ConstantInt not folded!"); if (Op0Max.ule(Op1Min)) // A <=u B -> true if max(A) <= min(B) - return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if (Op0Min.ugt(Op1Max)) // A <=u B -> false if min(A) > max(B) - return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); break; } @@ -3179,12 +3617,22 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { // See if we are doing a comparison between a constant and an instruction that // can be folded into the comparison. if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { + Value *A = nullptr, *B = nullptr; // Since the RHS is a ConstantInt (CI), if the left hand side is an // instruction, see if that instruction also has constants so that the // instruction can be folded into the icmp if (Instruction *LHSI = dyn_cast<Instruction>(Op0)) if (Instruction *Res = visitICmpInstWithInstAndIntCst(I, LHSI, CI)) return Res; + + // (icmp eq/ne (udiv A, B), 0) -> (icmp ugt/ule i32 B, A) + if (I.isEquality() && CI->isZero() && + match(Op0, m_UDiv(m_Value(A), m_Value(B)))) { + ICmpInst::Predicate Pred = I.getPredicate() == ICmpInst::ICMP_EQ + ? ICmpInst::ICMP_UGT + : ICmpInst::ICMP_ULE; + return new ICmpInst(Pred, B, A); + } } // Handle icmp with constant (but not simple integer constant) RHS @@ -3354,10 +3802,14 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { // Analyze the case when either Op0 or Op1 is an add instruction. // Op0 = A + B (or A and B are null); Op1 = C + D (or C and D are null). Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr; - if (BO0 && BO0->getOpcode() == Instruction::Add) - A = BO0->getOperand(0), B = BO0->getOperand(1); - if (BO1 && BO1->getOpcode() == Instruction::Add) - C = BO1->getOperand(0), D = BO1->getOperand(1); + if (BO0 && BO0->getOpcode() == Instruction::Add) { + A = BO0->getOperand(0); + B = BO0->getOperand(1); + } + if (BO1 && BO1->getOpcode() == Instruction::Add) { + C = BO1->getOperand(0); + D = BO1->getOperand(1); + } // icmp (X+cst) < 0 --> X < -cst if (NoOp0WrapProblem && ICmpInst::isSigned(Pred) && match(Op1, m_Zero())) @@ -3474,11 +3926,18 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { // Analyze the case when either Op0 or Op1 is a sub instruction. // Op0 = A - B (or A and B are null); Op1 = C - D (or C and D are null). 
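[Editor's aside, not part of the patch] The "(icmp eq/ne (udiv A, B), 0) -> (icmp ugt/ule B, A)" fold added a few hunks above removes the division entirely: for unsigned operands with a non-zero divisor, the quotient is zero exactly when the divisor exceeds the dividend. A small standalone check:

    // Editor's illustration: A / B == 0 exactly when B > A (unsigned, B != 0).
    #include <cassert>
    #include <cstdint>

    int main() {
      for (uint32_t a = 0; a < 64; ++a)
        for (uint32_t b = 1; b < 64; ++b)
          assert((a / b == 0) == (b > a));
      return 0;
    }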
- A = nullptr; B = nullptr; C = nullptr; D = nullptr; - if (BO0 && BO0->getOpcode() == Instruction::Sub) - A = BO0->getOperand(0), B = BO0->getOperand(1); - if (BO1 && BO1->getOpcode() == Instruction::Sub) - C = BO1->getOperand(0), D = BO1->getOperand(1); + A = nullptr; + B = nullptr; + C = nullptr; + D = nullptr; + if (BO0 && BO0->getOpcode() == Instruction::Sub) { + A = BO0->getOperand(0); + B = BO0->getOperand(1); + } + if (BO1 && BO1->getOpcode() == Instruction::Sub) { + C = BO1->getOperand(0); + D = BO1->getOperand(1); + } // icmp (X-Y), X -> icmp 0, Y for equalities or if there is no overflow. if (A == Op1 && NoOp0WrapProblem) @@ -3525,9 +3984,9 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { switch (SRem == BO0 ? ICmpInst::getSwappedPredicate(Pred) : Pred) { default: break; case ICmpInst::ICMP_EQ: - return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); case ICmpInst::ICMP_NE: - return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_SGE: return new ICmpInst(ICmpInst::ICMP_SGT, SRem->getOperand(1), @@ -3654,8 +4113,8 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { Constant *Overflow; if (OptimizeOverflowCheck(OCF_UNSIGNED_ADD, A, B, *AddI, Result, Overflow)) { - ReplaceInstUsesWith(*AddI, Result); - return ReplaceInstUsesWith(I, Overflow); + replaceInstUsesWith(*AddI, Result); + return replaceInstUsesWith(I, Overflow); } } @@ -3834,7 +4293,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { return Changed ? &I : nullptr; } -/// FoldFCmp_IntToFP_Cst - Fold fcmp ([us]itofp x, cst) if possible. +/// Fold fcmp ([us]itofp x, cst) if possible. Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I, Instruction *LHSI, Constant *RHSC) { @@ -3864,10 +4323,10 @@ Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I, RHSRoundInt.roundToIntegral(APFloat::rmNearestTiesToEven); if (RHS.compare(RHSRoundInt) != APFloat::cmpEqual) { if (P == FCmpInst::FCMP_OEQ || P == FCmpInst::FCMP_UEQ) - return ReplaceInstUsesWith(I, Builder->getFalse()); + return replaceInstUsesWith(I, Builder->getFalse()); assert(P == FCmpInst::FCMP_ONE || P == FCmpInst::FCMP_UNE); - return ReplaceInstUsesWith(I, Builder->getTrue()); + return replaceInstUsesWith(I, Builder->getTrue()); } } @@ -3933,9 +4392,9 @@ Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I, Pred = ICmpInst::ICMP_NE; break; case FCmpInst::FCMP_ORD: - return ReplaceInstUsesWith(I, Builder->getTrue()); + return replaceInstUsesWith(I, Builder->getTrue()); case FCmpInst::FCMP_UNO: - return ReplaceInstUsesWith(I, Builder->getFalse()); + return replaceInstUsesWith(I, Builder->getFalse()); } // Now we know that the APFloat is a normal number, zero or inf. @@ -3953,8 +4412,8 @@ Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I, if (SMax.compare(RHS) == APFloat::cmpLessThan) { // smax < 13123.0 if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) - return ReplaceInstUsesWith(I, Builder->getTrue()); - return ReplaceInstUsesWith(I, Builder->getFalse()); + return replaceInstUsesWith(I, Builder->getTrue()); + return replaceInstUsesWith(I, Builder->getFalse()); } } else { // If the RHS value is > UnsignedMax, fold the comparison. 
This handles @@ -3965,8 +4424,8 @@ Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I, if (UMax.compare(RHS) == APFloat::cmpLessThan) { // umax < 13123.0 if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE) - return ReplaceInstUsesWith(I, Builder->getTrue()); - return ReplaceInstUsesWith(I, Builder->getFalse()); + return replaceInstUsesWith(I, Builder->getTrue()); + return replaceInstUsesWith(I, Builder->getFalse()); } } @@ -3978,8 +4437,8 @@ Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I, if (SMin.compare(RHS) == APFloat::cmpGreaterThan) { // smin > 12312.0 if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) - return ReplaceInstUsesWith(I, Builder->getTrue()); - return ReplaceInstUsesWith(I, Builder->getFalse()); + return replaceInstUsesWith(I, Builder->getTrue()); + return replaceInstUsesWith(I, Builder->getFalse()); } } else { // See if the RHS value is < UnsignedMin. @@ -3989,8 +4448,8 @@ Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I, if (SMin.compare(RHS) == APFloat::cmpGreaterThan) { // umin > 12312.0 if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) - return ReplaceInstUsesWith(I, Builder->getTrue()); - return ReplaceInstUsesWith(I, Builder->getFalse()); + return replaceInstUsesWith(I, Builder->getTrue()); + return replaceInstUsesWith(I, Builder->getFalse()); } } @@ -4012,14 +4471,14 @@ Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I, switch (Pred) { default: llvm_unreachable("Unexpected integer comparison!"); case ICmpInst::ICMP_NE: // (float)int != 4.4 --> true - return ReplaceInstUsesWith(I, Builder->getTrue()); + return replaceInstUsesWith(I, Builder->getTrue()); case ICmpInst::ICMP_EQ: // (float)int == 4.4 --> false - return ReplaceInstUsesWith(I, Builder->getFalse()); + return replaceInstUsesWith(I, Builder->getFalse()); case ICmpInst::ICMP_ULE: // (float)int <= 4.4 --> int <= 4 // (float)int <= -4.4 --> false if (RHS.isNegative()) - return ReplaceInstUsesWith(I, Builder->getFalse()); + return replaceInstUsesWith(I, Builder->getFalse()); break; case ICmpInst::ICMP_SLE: // (float)int <= 4.4 --> int <= 4 @@ -4031,7 +4490,7 @@ Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I, // (float)int < -4.4 --> false // (float)int < 4.4 --> int <= 4 if (RHS.isNegative()) - return ReplaceInstUsesWith(I, Builder->getFalse()); + return replaceInstUsesWith(I, Builder->getFalse()); Pred = ICmpInst::ICMP_ULE; break; case ICmpInst::ICMP_SLT: @@ -4044,7 +4503,7 @@ Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I, // (float)int > 4.4 --> int > 4 // (float)int > -4.4 --> true if (RHS.isNegative()) - return ReplaceInstUsesWith(I, Builder->getTrue()); + return replaceInstUsesWith(I, Builder->getTrue()); break; case ICmpInst::ICMP_SGT: // (float)int > 4.4 --> int > 4 @@ -4056,7 +4515,7 @@ Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I, // (float)int >= -4.4 --> true // (float)int >= 4.4 --> int > 4 if (RHS.isNegative()) - return ReplaceInstUsesWith(I, Builder->getTrue()); + return replaceInstUsesWith(I, Builder->getTrue()); Pred = ICmpInst::ICMP_UGT; break; case ICmpInst::ICMP_SGE: @@ -4089,7 +4548,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { if (Value *V = SimplifyFCmpInst(I.getPredicate(), Op0, Op1, I.getFastMathFlags(), DL, TLI, DT, AC, &I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // Simplify 'fcmp pred X, X' if (Op0 == Op1) { @@ -4208,39 +4667,33 @@ 
Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { break; CallInst *CI = cast<CallInst>(LHSI); - const Function *F = CI->getCalledFunction(); - if (!F) + Intrinsic::ID IID = getIntrinsicForCallSite(CI, TLI); + if (IID != Intrinsic::fabs) break; // Various optimization for fabs compared with zero. - LibFunc::Func Func; - if (F->getIntrinsicID() == Intrinsic::fabs || - (TLI->getLibFunc(F->getName(), Func) && TLI->has(Func) && - (Func == LibFunc::fabs || Func == LibFunc::fabsf || - Func == LibFunc::fabsl))) { - switch (I.getPredicate()) { - default: - break; - // fabs(x) < 0 --> false - case FCmpInst::FCMP_OLT: - return ReplaceInstUsesWith(I, Builder->getFalse()); - // fabs(x) > 0 --> x != 0 - case FCmpInst::FCMP_OGT: - return new FCmpInst(FCmpInst::FCMP_ONE, CI->getArgOperand(0), RHSC); - // fabs(x) <= 0 --> x == 0 - case FCmpInst::FCMP_OLE: - return new FCmpInst(FCmpInst::FCMP_OEQ, CI->getArgOperand(0), RHSC); - // fabs(x) >= 0 --> !isnan(x) - case FCmpInst::FCMP_OGE: - return new FCmpInst(FCmpInst::FCMP_ORD, CI->getArgOperand(0), RHSC); - // fabs(x) == 0 --> x == 0 - // fabs(x) != 0 --> x != 0 - case FCmpInst::FCMP_OEQ: - case FCmpInst::FCMP_UEQ: - case FCmpInst::FCMP_ONE: - case FCmpInst::FCMP_UNE: - return new FCmpInst(I.getPredicate(), CI->getArgOperand(0), RHSC); - } + switch (I.getPredicate()) { + default: + break; + // fabs(x) < 0 --> false + case FCmpInst::FCMP_OLT: + llvm_unreachable("handled by SimplifyFCmpInst"); + // fabs(x) > 0 --> x != 0 + case FCmpInst::FCMP_OGT: + return new FCmpInst(FCmpInst::FCMP_ONE, CI->getArgOperand(0), RHSC); + // fabs(x) <= 0 --> x == 0 + case FCmpInst::FCMP_OLE: + return new FCmpInst(FCmpInst::FCMP_OEQ, CI->getArgOperand(0), RHSC); + // fabs(x) >= 0 --> !isnan(x) + case FCmpInst::FCMP_OGE: + return new FCmpInst(FCmpInst::FCMP_ORD, CI->getArgOperand(0), RHSC); + // fabs(x) == 0 --> x == 0 + // fabs(x) != 0 --> x != 0 + case FCmpInst::FCMP_OEQ: + case FCmpInst::FCMP_UEQ: + case FCmpInst::FCMP_ONE: + case FCmpInst::FCMP_UNE: + return new FCmpInst(I.getPredicate(), CI->getArgOperand(0), RHSC); } } } diff --git a/lib/Transforms/InstCombine/InstCombineInternal.h b/lib/Transforms/InstCombine/InstCombineInternal.h index e4e506509d392..aa421ff594fb8 100644 --- a/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/lib/Transforms/InstCombine/InstCombineInternal.h @@ -138,7 +138,7 @@ IntrinsicIDToOverflowCheckFlavor(unsigned ID) { /// \brief An IRBuilder inserter that adds new instructions to the instcombine /// worklist. class LLVM_LIBRARY_VISIBILITY InstCombineIRInserter - : public IRBuilderDefaultInserter<true> { + : public IRBuilderDefaultInserter { InstCombineWorklist &Worklist; AssumptionCache *AC; @@ -148,7 +148,7 @@ public: void InsertHelper(Instruction *I, const Twine &Name, BasicBlock *BB, BasicBlock::iterator InsertPt) const { - IRBuilderDefaultInserter<true>::InsertHelper(I, Name, BB, InsertPt); + IRBuilderDefaultInserter::InsertHelper(I, Name, BB, InsertPt); Worklist.Add(I); using namespace llvm::PatternMatch; @@ -171,12 +171,14 @@ public: /// \brief An IRBuilder that automatically inserts new instructions into the /// worklist. - typedef IRBuilder<true, TargetFolder, InstCombineIRInserter> BuilderTy; + typedef IRBuilder<TargetFolder, InstCombineIRInserter> BuilderTy; BuilderTy *Builder; private: // Mode in which we are running the combiner. const bool MinimizeSize; + /// Enable combines that trigger rarely but are costly in compiletime. 
+ const bool ExpensiveCombines; AliasAnalysis *AA; @@ -195,11 +197,12 @@ private: public: InstCombiner(InstCombineWorklist &Worklist, BuilderTy *Builder, - bool MinimizeSize, AliasAnalysis *AA, + bool MinimizeSize, bool ExpensiveCombines, AliasAnalysis *AA, AssumptionCache *AC, TargetLibraryInfo *TLI, DominatorTree *DT, const DataLayout &DL, LoopInfo *LI) : Worklist(Worklist), Builder(Builder), MinimizeSize(MinimizeSize), - AA(AA), AC(AC), TLI(TLI), DT(DT), DL(DL), LI(LI), MadeIRChange(false) {} + ExpensiveCombines(ExpensiveCombines), AA(AA), AC(AC), TLI(TLI), DT(DT), + DL(DL), LI(LI), MadeIRChange(false) {} /// \brief Run the combiner over the entire worklist until it is empty. /// @@ -327,6 +330,8 @@ public: Instruction *visitShuffleVectorInst(ShuffleVectorInst &SVI); Instruction *visitExtractValueInst(ExtractValueInst &EV); Instruction *visitLandingPadInst(LandingPadInst &LI); + Instruction *visitVAStartInst(VAStartInst &I); + Instruction *visitVACopyInst(VACopyInst &I); // visitInstruction - Specify what to return for unhandled instructions... Instruction *visitInstruction(Instruction &I) { return nullptr; } @@ -390,6 +395,7 @@ private: Value *EmitGEPOffset(User *GEP); Instruction *scalarizePHI(ExtractElementInst &EI, PHINode *PN); Value *EvaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask); + Instruction *foldCastedBitwiseLogic(BinaryOperator &I); public: /// \brief Inserts an instruction \p New before instruction \p Old @@ -417,7 +423,7 @@ public: /// replaceable with another preexisting expression. Here we add all uses of /// I to the worklist, replace all uses of I with the new value, then return /// I, so that the inst combiner will know that I was modified. - Instruction *ReplaceInstUsesWith(Instruction &I, Value *V) { + Instruction *replaceInstUsesWith(Instruction &I, Value *V) { // If there are no uses to replace, then we return nullptr to indicate that // no changes were made to the program. if (I.use_empty()) return nullptr; @@ -451,16 +457,16 @@ public: /// When dealing with an instruction that has side effects or produces a void /// value, we can't rely on DCE to delete the instruction. Instead, visit /// methods should return the value returned by this function. - Instruction *EraseInstFromFunction(Instruction &I) { + Instruction *eraseInstFromFunction(Instruction &I) { DEBUG(dbgs() << "IC: ERASE " << I << '\n'); assert(I.use_empty() && "Cannot erase instruction that is used!"); // Make sure that we reprocess all operands now that we reduced their // use counts. if (I.getNumOperands() < 8) { - for (User::op_iterator i = I.op_begin(), e = I.op_end(); i != e; ++i) - if (Instruction *Op = dyn_cast<Instruction>(*i)) - Worklist.Add(Op); + for (Use &Operand : I.operands()) + if (auto *Inst = dyn_cast<Instruction>(Operand)) + Worklist.Add(Inst); } Worklist.Remove(&I); I.eraseFromParent(); @@ -515,12 +521,12 @@ private: Value *SimplifyDemandedUseBits(Value *V, APInt DemandedMask, APInt &KnownZero, APInt &KnownOne, unsigned Depth, Instruction *CxtI); - bool SimplifyDemandedBits(Use &U, APInt DemandedMask, APInt &KnownZero, + bool SimplifyDemandedBits(Use &U, const APInt &DemandedMask, APInt &KnownZero, APInt &KnownOne, unsigned Depth = 0); /// Helper routine of SimplifyDemandedUseBits. It tries to simplify demanded /// bit for "r1 = shr x, c1; r2 = shl r1, c2" instruction sequence. 
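[Editor's aside, not part of the patch] For the "r1 = shr x, c1; r2 = shl r1, c2" sequence named in the comment above, the simplest case is c1 == c2 == C, where the pair of shifts just clears the C low bits; SimplifyShrShlDemandedBits generalizes this with demanded-bits reasoning. A standalone check of that special case:

    // Editor's illustration: for equal shift amounts C, (x >> C) << C is a mask
    // that clears the C low bits (logical shift on an unsigned value).
    #include <cassert>
    #include <cstdint>

    int main() {
      const unsigned C = 3;
      for (uint32_t x = 0; x < 1024; ++x)
        assert(((x >> C) << C) == (x & ~((1u << C) - 1)));
      return 0;
    }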
Value *SimplifyShrShlDemandedBits(Instruction *Lsr, Instruction *Sftl, - APInt DemandedMask, APInt &KnownZero, + const APInt &DemandedMask, APInt &KnownZero, APInt &KnownOne); /// \brief Tries to simplify operands to an integer instruction based on its @@ -556,7 +562,7 @@ private: Value *InsertRangeTest(Value *V, Constant *Lo, Constant *Hi, bool isSigned, bool Inside); Instruction *PromoteCastOfAllocation(BitCastInst &CI, AllocaInst &AI); - Instruction *MatchBSwapOrBitReverse(BinaryOperator &I); + Instruction *MatchBSwap(BinaryOperator &I); bool SimplifyStoreAtEndOfBlock(StoreInst &SI); Instruction *SimplifyMemTransfer(MemIntrinsic *MI); Instruction *SimplifyMemSet(MemSetInst *MI); diff --git a/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp index dd2889de405e0..d312983ed51b0 100644 --- a/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ b/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -205,11 +205,11 @@ static Instruction *simplifyAllocaArraySize(InstCombiner &IC, AllocaInst &AI) { // Now make everything use the getelementptr instead of the original // allocation. - return IC.ReplaceInstUsesWith(AI, GEP); + return IC.replaceInstUsesWith(AI, GEP); } if (isa<UndefValue>(AI.getArraySize())) - return IC.ReplaceInstUsesWith(AI, Constant::getNullValue(AI.getType())); + return IC.replaceInstUsesWith(AI, Constant::getNullValue(AI.getType())); // Ensure that the alloca array size argument has type intptr_t, so that // any casting is exposed early. @@ -271,7 +271,7 @@ Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) { EntryAI->setAlignment(MaxAlign); if (AI.getType() != EntryAI->getType()) return new BitCastInst(EntryAI, AI.getType()); - return ReplaceInstUsesWith(AI, EntryAI); + return replaceInstUsesWith(AI, EntryAI); } } } @@ -291,12 +291,12 @@ Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) { DEBUG(dbgs() << "Found alloca equal to global: " << AI << '\n'); DEBUG(dbgs() << " memcpy = " << *Copy << '\n'); for (unsigned i = 0, e = ToDelete.size(); i != e; ++i) - EraseInstFromFunction(*ToDelete[i]); + eraseInstFromFunction(*ToDelete[i]); Constant *TheSrc = cast<Constant>(Copy->getSource()); Constant *Cast = ConstantExpr::getPointerBitCastOrAddrSpaceCast(TheSrc, AI.getType()); - Instruction *NewI = ReplaceInstUsesWith(AI, Cast); - EraseInstFromFunction(*Copy); + Instruction *NewI = replaceInstUsesWith(AI, Cast); + eraseInstFromFunction(*Copy); ++NumGlobalCopies; return NewI; } @@ -326,7 +326,8 @@ static LoadInst *combineLoadToNewType(InstCombiner &IC, LoadInst &LI, Type *NewT LoadInst *NewLoad = IC.Builder->CreateAlignedLoad( IC.Builder->CreateBitCast(Ptr, NewTy->getPointerTo(AS)), - LI.getAlignment(), LI.getName() + Suffix); + LI.getAlignment(), LI.isVolatile(), LI.getName() + Suffix); + NewLoad->setAtomic(LI.getOrdering(), LI.getSynchScope()); MDBuilder MDB(NewLoad->getContext()); for (const auto &MDPair : MD) { unsigned ID = MDPair.first; @@ -398,7 +399,8 @@ static StoreInst *combineStoreToNewValue(InstCombiner &IC, StoreInst &SI, Value StoreInst *NewStore = IC.Builder->CreateAlignedStore( V, IC.Builder->CreateBitCast(Ptr, V->getType()->getPointerTo(AS)), - SI.getAlignment()); + SI.getAlignment(), SI.isVolatile()); + NewStore->setAtomic(SI.getOrdering(), SI.getSynchScope()); for (const auto &MDPair : MD) { unsigned ID = MDPair.first; MDNode *N = MDPair.second; @@ -438,7 +440,7 @@ static StoreInst *combineStoreToNewValue(InstCombiner &IC, StoreInst &SI, Value return NewStore; } -/// 
\brief Combine loads to match the type of value their uses after looking +/// \brief Combine loads to match the type of their uses' value after looking /// through intervening bitcasts. /// /// The core idea here is that if the result of a load is used in an operation, @@ -456,9 +458,9 @@ static StoreInst *combineStoreToNewValue(InstCombiner &IC, StoreInst &SI, Value /// later. However, it is risky in case some backend or other part of LLVM is /// relying on the exact type loaded to select appropriate atomic operations. static Instruction *combineLoadToOperationType(InstCombiner &IC, LoadInst &LI) { - // FIXME: We could probably with some care handle both volatile and atomic - // loads here but it isn't clear that this is important. - if (!LI.isSimple()) + // FIXME: We could probably with some care handle both volatile and ordered + // atomic loads here but it isn't clear that this is important. + if (!LI.isUnordered()) return nullptr; if (LI.use_empty()) @@ -486,7 +488,7 @@ static Instruction *combineLoadToOperationType(InstCombiner &IC, LoadInst &LI) { auto *SI = cast<StoreInst>(*UI++); IC.Builder->SetInsertPoint(SI); combineStoreToNewValue(IC, *SI, NewLoad); - IC.EraseInstFromFunction(*SI); + IC.eraseInstFromFunction(*SI); } assert(LI.use_empty() && "Failed to remove all users of the load!"); // Return the old load so the combiner can delete it safely. @@ -503,7 +505,7 @@ static Instruction *combineLoadToOperationType(InstCombiner &IC, LoadInst &LI) { if (CI->isNoopCast(DL)) { LoadInst *NewLoad = combineLoadToNewType(IC, LI, CI->getDestTy()); CI->replaceAllUsesWith(NewLoad); - IC.EraseInstFromFunction(*CI); + IC.eraseInstFromFunction(*CI); return &LI; } } @@ -523,16 +525,17 @@ static Instruction *unpackLoadToAggregate(InstCombiner &IC, LoadInst &LI) { if (!T->isAggregateType()) return nullptr; + StringRef Name = LI.getName(); assert(LI.getAlignment() && "Alignment must be set at this point"); if (auto *ST = dyn_cast<StructType>(T)) { // If the struct only have one element, we unpack. 
- unsigned Count = ST->getNumElements(); - if (Count == 1) { + auto NumElements = ST->getNumElements(); + if (NumElements == 1) { LoadInst *NewLoad = combineLoadToNewType(IC, LI, ST->getTypeAtIndex(0U), ".unpack"); - return IC.ReplaceInstUsesWith(LI, IC.Builder->CreateInsertValue( - UndefValue::get(T), NewLoad, 0, LI.getName())); + return IC.replaceInstUsesWith(LI, IC.Builder->CreateInsertValue( + UndefValue::get(T), NewLoad, 0, Name)); } // We don't want to break loads with padding here as we'd loose @@ -542,38 +545,67 @@ static Instruction *unpackLoadToAggregate(InstCombiner &IC, LoadInst &LI) { if (SL->hasPadding()) return nullptr; - auto Name = LI.getName(); - SmallString<16> LoadName = Name; - LoadName += ".unpack"; - SmallString<16> EltName = Name; - EltName += ".elt"; + auto Align = LI.getAlignment(); + if (!Align) + Align = DL.getABITypeAlignment(ST); + auto *Addr = LI.getPointerOperand(); - Value *V = UndefValue::get(T); - auto *IdxType = Type::getInt32Ty(ST->getContext()); + auto *IdxType = Type::getInt32Ty(T->getContext()); auto *Zero = ConstantInt::get(IdxType, 0); - for (unsigned i = 0; i < Count; i++) { + + Value *V = UndefValue::get(T); + for (unsigned i = 0; i < NumElements; i++) { Value *Indices[2] = { Zero, ConstantInt::get(IdxType, i), }; - auto *Ptr = IC.Builder->CreateInBoundsGEP(ST, Addr, makeArrayRef(Indices), EltName); - auto *L = IC.Builder->CreateAlignedLoad(Ptr, LI.getAlignment(), - LoadName); + auto *Ptr = IC.Builder->CreateInBoundsGEP(ST, Addr, makeArrayRef(Indices), + Name + ".elt"); + auto EltAlign = MinAlign(Align, SL->getElementOffset(i)); + auto *L = IC.Builder->CreateAlignedLoad(Ptr, EltAlign, Name + ".unpack"); V = IC.Builder->CreateInsertValue(V, L, i); } V->setName(Name); - return IC.ReplaceInstUsesWith(LI, V); + return IC.replaceInstUsesWith(LI, V); } if (auto *AT = dyn_cast<ArrayType>(T)) { - // If the array only have one element, we unpack. 
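[Editor's aside, not part of the patch] The per-element loads created in the unpacking code above take their alignment from MinAlign(Align, Offset): the intent is the largest power of two that divides both the aggregate's alignment and the element's byte offset, so later elements may only be guaranteed a smaller alignment than the first. A hypothetical re-implementation, for illustration only:

    // Editor's sketch of the intended per-element alignment; assumes Align is a
    // power of two. Not the LLVM helper itself.
    #include <cassert>
    #include <cstdint>

    static uint64_t elementAlign(uint64_t Align, uint64_t Offset) {
      uint64_t Bits = Align | Offset;
      return Bits & (~Bits + 1); // isolate the lowest set bit
    }

    int main() {
      assert(elementAlign(16, 0) == 16); // first element keeps the full alignment
      assert(elementAlign(16, 8) == 8);
      assert(elementAlign(16, 12) == 4);
      assert(elementAlign(8, 4) == 4);
      return 0;
    }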
- if (AT->getNumElements() == 1) { - LoadInst *NewLoad = combineLoadToNewType(IC, LI, AT->getElementType(), - ".unpack"); - return IC.ReplaceInstUsesWith(LI, IC.Builder->CreateInsertValue( - UndefValue::get(T), NewLoad, 0, LI.getName())); + auto *ET = AT->getElementType(); + auto NumElements = AT->getNumElements(); + if (NumElements == 1) { + LoadInst *NewLoad = combineLoadToNewType(IC, LI, ET, ".unpack"); + return IC.replaceInstUsesWith(LI, IC.Builder->CreateInsertValue( + UndefValue::get(T), NewLoad, 0, Name)); } + + const DataLayout &DL = IC.getDataLayout(); + auto EltSize = DL.getTypeAllocSize(ET); + auto Align = LI.getAlignment(); + if (!Align) + Align = DL.getABITypeAlignment(T); + + auto *Addr = LI.getPointerOperand(); + auto *IdxType = Type::getInt64Ty(T->getContext()); + auto *Zero = ConstantInt::get(IdxType, 0); + + Value *V = UndefValue::get(T); + uint64_t Offset = 0; + for (uint64_t i = 0; i < NumElements; i++) { + Value *Indices[2] = { + Zero, + ConstantInt::get(IdxType, i), + }; + auto *Ptr = IC.Builder->CreateInBoundsGEP(AT, Addr, makeArrayRef(Indices), + Name + ".elt"); + auto *L = IC.Builder->CreateAlignedLoad(Ptr, MinAlign(Align, Offset), + Name + ".unpack"); + V = IC.Builder->CreateInsertValue(V, L, i); + Offset += EltSize; + } + + V->setName(Name); + return IC.replaceInstUsesWith(LI, V); } return nullptr; @@ -610,7 +642,7 @@ static bool isObjectSizeLessThanOrEq(Value *V, uint64_t MaxSize, } if (GlobalAlias *GA = dyn_cast<GlobalAlias>(P)) { - if (GA->mayBeOverridden()) + if (GA->isInterposable()) return false; Worklist.push_back(GA->getAliasee()); continue; @@ -638,7 +670,7 @@ static bool isObjectSizeLessThanOrEq(Value *V, uint64_t MaxSize, if (!GV->hasDefinitiveInitializer() || !GV->isConstant()) return false; - uint64_t InitSize = DL.getTypeAllocSize(GV->getType()->getElementType()); + uint64_t InitSize = DL.getTypeAllocSize(GV->getValueType()); if (InitSize > MaxSize) return false; continue; @@ -695,10 +727,8 @@ static bool canReplaceGEPIdxWithZero(InstCombiner &IC, GetElementPtrInst *GEPI, return false; SmallVector<Value *, 4> Ops(GEPI->idx_begin(), GEPI->idx_begin() + Idx); - Type *AllocTy = GetElementPtrInst::getIndexedType( - cast<PointerType>(GEPI->getOperand(0)->getType()->getScalarType()) - ->getElementType(), - Ops); + Type *AllocTy = + GetElementPtrInst::getIndexedType(GEPI->getSourceElementType(), Ops); if (!AllocTy || !AllocTy->isSized()) return false; const DataLayout &DL = IC.getDataLayout(); @@ -781,10 +811,6 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { return &LI; } - // None of the following transforms are legal for volatile/atomic loads. - // FIXME: Some of it is okay for atomic loads; needs refactoring. - if (!LI.isSimple()) return nullptr; - if (Instruction *Res = unpackLoadToAggregate(*this, LI)) return Res; @@ -793,10 +819,12 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { // separated by a few arithmetic operations. 
BasicBlock::iterator BBI(LI); AAMDNodes AATags; + bool IsLoadCSE = false; if (Value *AvailableVal = - FindAvailableLoadedValue(Op, LI.getParent(), BBI, - DefMaxInstsToScan, AA, &AATags)) { - if (LoadInst *NLI = dyn_cast<LoadInst>(AvailableVal)) { + FindAvailableLoadedValue(&LI, LI.getParent(), BBI, + DefMaxInstsToScan, AA, &AATags, &IsLoadCSE)) { + if (IsLoadCSE) { + LoadInst *NLI = cast<LoadInst>(AvailableVal); unsigned KnownIDs[] = { LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope, LLVMContext::MD_noalias, LLVMContext::MD_range, @@ -807,11 +835,15 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { combineMetadata(NLI, &LI, KnownIDs); }; - return ReplaceInstUsesWith( + return replaceInstUsesWith( LI, Builder->CreateBitOrPointerCast(AvailableVal, LI.getType(), LI.getName() + ".cast")); } + // None of the following transforms are legal for volatile/ordered atomic + // loads. Most of them do apply for unordered atomics. + if (!LI.isUnordered()) return nullptr; + // load(gep null, ...) -> unreachable if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(Op)) { const Value *GEPI0 = GEPI->getOperand(0); @@ -823,7 +855,7 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { // CFG. new StoreInst(UndefValue::get(LI.getType()), Constant::getNullValue(Op->getType()), &LI); - return ReplaceInstUsesWith(LI, UndefValue::get(LI.getType())); + return replaceInstUsesWith(LI, UndefValue::get(LI.getType())); } } @@ -836,7 +868,7 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { // unreachable instruction directly because we cannot modify the CFG. new StoreInst(UndefValue::get(LI.getType()), Constant::getNullValue(Op->getType()), &LI); - return ReplaceInstUsesWith(LI, UndefValue::get(LI.getType())); + return replaceInstUsesWith(LI, UndefValue::get(LI.getType())); } if (Op->hasOneUse()) { @@ -853,14 +885,17 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { if (SelectInst *SI = dyn_cast<SelectInst>(Op)) { // load (select (Cond, &V1, &V2)) --> select(Cond, load &V1, load &V2). unsigned Align = LI.getAlignment(); - if (isSafeToLoadUnconditionally(SI->getOperand(1), SI, Align) && - isSafeToLoadUnconditionally(SI->getOperand(2), SI, Align)) { + if (isSafeToLoadUnconditionally(SI->getOperand(1), Align, DL, SI) && + isSafeToLoadUnconditionally(SI->getOperand(2), Align, DL, SI)) { LoadInst *V1 = Builder->CreateLoad(SI->getOperand(1), SI->getOperand(1)->getName()+".val"); LoadInst *V2 = Builder->CreateLoad(SI->getOperand(2), SI->getOperand(2)->getName()+".val"); + assert(LI.isUnordered() && "implied by above"); V1->setAlignment(Align); + V1->setAtomic(LI.getOrdering(), LI.getSynchScope()); V2->setAlignment(Align); + V2->setAtomic(LI.getOrdering(), LI.getSynchScope()); return SelectInst::Create(SI->getCondition(), V1, V2); } @@ -882,6 +917,61 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { return nullptr; } +/// \brief Look for extractelement/insertvalue sequence that acts like a bitcast. +/// +/// \returns underlying value that was "cast", or nullptr otherwise. +/// +/// For example, if we have: +/// +/// %E0 = extractelement <2 x double> %U, i32 0 +/// %V0 = insertvalue [2 x double] undef, double %E0, 0 +/// %E1 = extractelement <2 x double> %U, i32 1 +/// %V1 = insertvalue [2 x double] %V0, double %E1, 1 +/// +/// and the layout of a <2 x double> is isomorphic to a [2 x double], +/// then %V1 can be safely approximated by a conceptual "bitcast" of %U. +/// Note that %U may contain non-undef values where %V1 has undef. 
+static Value *likeBitCastFromVector(InstCombiner &IC, Value *V) { + Value *U = nullptr; + while (auto *IV = dyn_cast<InsertValueInst>(V)) { + auto *E = dyn_cast<ExtractElementInst>(IV->getInsertedValueOperand()); + if (!E) + return nullptr; + auto *W = E->getVectorOperand(); + if (!U) + U = W; + else if (U != W) + return nullptr; + auto *CI = dyn_cast<ConstantInt>(E->getIndexOperand()); + if (!CI || IV->getNumIndices() != 1 || CI->getZExtValue() != *IV->idx_begin()) + return nullptr; + V = IV->getAggregateOperand(); + } + if (!isa<UndefValue>(V) ||!U) + return nullptr; + + auto *UT = cast<VectorType>(U->getType()); + auto *VT = V->getType(); + // Check that types UT and VT are bitwise isomorphic. + const auto &DL = IC.getDataLayout(); + if (DL.getTypeStoreSizeInBits(UT) != DL.getTypeStoreSizeInBits(VT)) { + return nullptr; + } + if (auto *AT = dyn_cast<ArrayType>(VT)) { + if (AT->getNumElements() != UT->getNumElements()) + return nullptr; + } else { + auto *ST = cast<StructType>(VT); + if (ST->getNumElements() != UT->getNumElements()) + return nullptr; + for (const auto *EltT : ST->elements()) { + if (EltT != UT->getElementType()) + return nullptr; + } + } + return U; +} + /// \brief Combine stores to match the type of value being stored. /// /// The core idea here is that the memory does not have any intrinsic type and @@ -903,9 +993,9 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { /// the store instruction as otherwise there is no way to signal whether it was /// combined or not: IC.EraseInstFromFunction returns a null pointer. static bool combineStoreToValueType(InstCombiner &IC, StoreInst &SI) { - // FIXME: We could probably with some care handle both volatile and atomic - // stores here but it isn't clear that this is important. - if (!SI.isSimple()) + // FIXME: We could probably with some care handle both volatile and ordered + // atomic stores here but it isn't clear that this is important. + if (!SI.isUnordered()) return false; Value *V = SI.getValueOperand(); @@ -917,8 +1007,13 @@ static bool combineStoreToValueType(InstCombiner &IC, StoreInst &SI) { return true; } - // FIXME: We should also canonicalize loads of vectors when their elements are - // cast to other types. + if (Value *U = likeBitCastFromVector(IC, V)) { + combineStoreToNewValue(IC, SI, U); + return true; + } + + // FIXME: We should also canonicalize stores of vectors when their elements + // are cast to other types. 
return false; } @@ -950,11 +1045,16 @@ static bool unpackStoreToAggregate(InstCombiner &IC, StoreInst &SI) { if (SL->hasPadding()) return false; + auto Align = SI.getAlignment(); + if (!Align) + Align = DL.getABITypeAlignment(ST); + SmallString<16> EltName = V->getName(); EltName += ".elt"; auto *Addr = SI.getPointerOperand(); SmallString<16> AddrName = Addr->getName(); AddrName += ".repack"; + auto *IdxType = Type::getInt32Ty(ST->getContext()); auto *Zero = ConstantInt::get(IdxType, 0); for (unsigned i = 0; i < Count; i++) { @@ -962,9 +1062,11 @@ static bool unpackStoreToAggregate(InstCombiner &IC, StoreInst &SI) { Zero, ConstantInt::get(IdxType, i), }; - auto *Ptr = IC.Builder->CreateInBoundsGEP(ST, Addr, makeArrayRef(Indices), AddrName); + auto *Ptr = IC.Builder->CreateInBoundsGEP(ST, Addr, makeArrayRef(Indices), + AddrName); auto *Val = IC.Builder->CreateExtractValue(V, i, EltName); - IC.Builder->CreateStore(Val, Ptr); + auto EltAlign = MinAlign(Align, SL->getElementOffset(i)); + IC.Builder->CreateAlignedStore(Val, Ptr, EltAlign); } return true; @@ -972,11 +1074,43 @@ static bool unpackStoreToAggregate(InstCombiner &IC, StoreInst &SI) { if (auto *AT = dyn_cast<ArrayType>(T)) { // If the array only have one element, we unpack. - if (AT->getNumElements() == 1) { + auto NumElements = AT->getNumElements(); + if (NumElements == 1) { V = IC.Builder->CreateExtractValue(V, 0); combineStoreToNewValue(IC, SI, V); return true; } + + const DataLayout &DL = IC.getDataLayout(); + auto EltSize = DL.getTypeAllocSize(AT->getElementType()); + auto Align = SI.getAlignment(); + if (!Align) + Align = DL.getABITypeAlignment(T); + + SmallString<16> EltName = V->getName(); + EltName += ".elt"; + auto *Addr = SI.getPointerOperand(); + SmallString<16> AddrName = Addr->getName(); + AddrName += ".repack"; + + auto *IdxType = Type::getInt64Ty(T->getContext()); + auto *Zero = ConstantInt::get(IdxType, 0); + + uint64_t Offset = 0; + for (uint64_t i = 0; i < NumElements; i++) { + Value *Indices[2] = { + Zero, + ConstantInt::get(IdxType, i), + }; + auto *Ptr = IC.Builder->CreateInBoundsGEP(AT, Addr, makeArrayRef(Indices), + AddrName); + auto *Val = IC.Builder->CreateExtractValue(V, i, EltName); + auto EltAlign = MinAlign(Align, Offset); + IC.Builder->CreateAlignedStore(Val, Ptr, EltAlign); + Offset += EltSize; + } + + return true; } return false; @@ -1017,7 +1151,7 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { // Try to canonicalize the stored type. if (combineStoreToValueType(*this, SI)) - return EraseInstFromFunction(SI); + return eraseInstFromFunction(SI); // Attempt to improve the alignment. unsigned KnownAlign = getOrEnforceKnownAlignment( @@ -1033,7 +1167,7 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { // Try to canonicalize the stored type. if (unpackStoreToAggregate(*this, SI)) - return EraseInstFromFunction(SI); + return eraseInstFromFunction(SI); // Replace GEP indices if possible. if (Instruction *NewGEPI = replaceGEPIdxWithZero(*this, Ptr, SI)) { @@ -1049,11 +1183,11 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { // alloca dead. 
if (Ptr->hasOneUse()) { if (isa<AllocaInst>(Ptr)) - return EraseInstFromFunction(SI); + return eraseInstFromFunction(SI); if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr)) { if (isa<AllocaInst>(GEP->getOperand(0))) { if (GEP->getOperand(0)->hasOneUse()) - return EraseInstFromFunction(SI); + return eraseInstFromFunction(SI); } } } @@ -1079,7 +1213,7 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { SI.getOperand(1))) { ++NumDeadStore; ++BBI; - EraseInstFromFunction(*PrevSI); + eraseInstFromFunction(*PrevSI); continue; } break; @@ -1091,7 +1225,7 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { if (LoadInst *LI = dyn_cast<LoadInst>(BBI)) { if (LI == Val && equivalentAddressValues(LI->getOperand(0), Ptr)) { assert(SI.isUnordered() && "can't eliminate ordering operation"); - return EraseInstFromFunction(SI); + return eraseInstFromFunction(SI); } // Otherwise, this is a load from some other location. Stores before it @@ -1116,11 +1250,7 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { // store undef, Ptr -> noop if (isa<UndefValue>(Val)) - return EraseInstFromFunction(SI); - - // The code below needs to be audited and adjusted for unordered atomics - if (!SI.isSimple()) - return nullptr; + return eraseInstFromFunction(SI); // If this store is the last instruction in the basic block (possibly // excepting debug info instructions), and if the block ends with an @@ -1147,6 +1277,9 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { /// into a phi node with a store in the successor. /// bool InstCombiner::SimplifyStoreAtEndOfBlock(StoreInst &SI) { + assert(SI.isUnordered() && + "this code has not been auditted for volatile or ordered store case"); + BasicBlock *StoreBB = SI.getParent(); // Check to see if the successor block has exactly two incoming edges. If @@ -1268,7 +1401,7 @@ bool InstCombiner::SimplifyStoreAtEndOfBlock(StoreInst &SI) { } // Nuke the old stores. - EraseInstFromFunction(SI); - EraseInstFromFunction(*OtherStore); + eraseInstFromFunction(SI); + eraseInstFromFunction(*OtherStore); return true; } diff --git a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 160792b0a0000..788097f33f121 100644 --- a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -45,28 +45,28 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC, // (PowerOfTwo >>u B) --> isExact since shifting out the result would make it // inexact. Similarly for <<. - if (BinaryOperator *I = dyn_cast<BinaryOperator>(V)) - if (I->isLogicalShift() && - isKnownToBeAPowerOfTwo(I->getOperand(0), IC.getDataLayout(), false, 0, - IC.getAssumptionCache(), &CxtI, - IC.getDominatorTree())) { - // We know that this is an exact/nuw shift and that the input is a - // non-zero context as well. - if (Value *V2 = simplifyValueKnownNonZero(I->getOperand(0), IC, CxtI)) { - I->setOperand(0, V2); - MadeChange = true; - } + BinaryOperator *I = dyn_cast<BinaryOperator>(V); + if (I && I->isLogicalShift() && + isKnownToBeAPowerOfTwo(I->getOperand(0), IC.getDataLayout(), false, 0, + IC.getAssumptionCache(), &CxtI, + IC.getDominatorTree())) { + // We know that this is an exact/nuw shift and that the input is a + // non-zero context as well. 
+ if (Value *V2 = simplifyValueKnownNonZero(I->getOperand(0), IC, CxtI)) { + I->setOperand(0, V2); + MadeChange = true; + } - if (I->getOpcode() == Instruction::LShr && !I->isExact()) { - I->setIsExact(); - MadeChange = true; - } + if (I->getOpcode() == Instruction::LShr && !I->isExact()) { + I->setIsExact(); + MadeChange = true; + } - if (I->getOpcode() == Instruction::Shl && !I->hasNoUnsignedWrap()) { - I->setHasNoUnsignedWrap(); - MadeChange = true; - } + if (I->getOpcode() == Instruction::Shl && !I->hasNoUnsignedWrap()) { + I->setHasNoUnsignedWrap(); + MadeChange = true; } + } // TODO: Lots more we could do here: // If V is a phi node, we can call this on each of its operands. @@ -177,13 +177,13 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyMulInst(Op0, Op1, DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyUsingDistributiveLaws(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // X * -1 == 0 - X if (match(Op1, m_AllOnes())) { @@ -323,7 +323,7 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { if (PossiblyExactOperator *SDiv = dyn_cast<PossiblyExactOperator>(BO)) if (SDiv->isExact()) { if (Op1BO == Op1C) - return ReplaceInstUsesWith(I, Op0BO); + return replaceInstUsesWith(I, Op0BO); return BinaryOperator::CreateNeg(Op0BO); } @@ -374,10 +374,13 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { APInt Negative2(I.getType()->getPrimitiveSizeInBits(), (uint64_t)-2, true); Value *BoolCast = nullptr, *OtherOp = nullptr; - if (MaskedValueIsZero(Op0, Negative2, 0, &I)) - BoolCast = Op0, OtherOp = Op1; - else if (MaskedValueIsZero(Op1, Negative2, 0, &I)) - BoolCast = Op1, OtherOp = Op0; + if (MaskedValueIsZero(Op0, Negative2, 0, &I)) { + BoolCast = Op0; + OtherOp = Op1; + } else if (MaskedValueIsZero(Op1, Negative2, 0, &I)) { + BoolCast = Op1; + OtherOp = Op0; + } if (BoolCast) { Value *V = Builder->CreateSub(Constant::getNullValue(I.getType()), @@ -536,14 +539,14 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (isa<Constant>(Op0)) std::swap(Op0, Op1); if (Value *V = SimplifyFMulInst(Op0, Op1, I.getFastMathFlags(), DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); bool AllowReassociate = I.hasUnsafeAlgebra(); @@ -574,7 +577,7 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { // Try to simplify "MDC * Constant" if (isFMulOrFDivWithConstant(Op0)) if (Value *V = foldFMulConst(cast<Instruction>(Op0), C, &I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // (MDC +/- C1) * C => (MDC * C) +/- (C1 * C) Instruction *FAddSub = dyn_cast<Instruction>(Op0); @@ -612,11 +615,22 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { } } - // sqrt(X) * sqrt(X) -> X - if (AllowReassociate && (Op0 == Op1)) - if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op0)) - if (II->getIntrinsicID() == Intrinsic::sqrt) - return ReplaceInstUsesWith(I, II->getOperand(0)); + if (Op0 == Op1) { + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op0)) { + // sqrt(X) * sqrt(X) -> X + if (AllowReassociate && II->getIntrinsicID() == Intrinsic::sqrt) + return 
replaceInstUsesWith(I, II->getOperand(0)); + + // fabs(X) * fabs(X) -> X * X + if (II->getIntrinsicID() == Intrinsic::fabs) { + Instruction *FMulVal = BinaryOperator::CreateFMul(II->getOperand(0), + II->getOperand(0), + I.getName()); + FMulVal->copyFastMathFlags(&I); + return FMulVal; + } + } + } // Under unsafe algebra do: // X * log2(0.5*Y) = X*log2(Y) - X @@ -641,7 +655,7 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { Value *FMulVal = Builder->CreateFMul(OpX, Log2); Value *FSub = Builder->CreateFSub(FMulVal, OpX); FSub->takeName(&I); - return ReplaceInstUsesWith(I, FSub); + return replaceInstUsesWith(I, FSub); } } @@ -661,7 +675,7 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { if (N1) { Value *FMul = Builder->CreateFMul(N0, N1); FMul->takeName(&I); - return ReplaceInstUsesWith(I, FMul); + return replaceInstUsesWith(I, FMul); } if (Opnd0->hasOneUse()) { @@ -669,7 +683,7 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { Value *T = Builder->CreateFMul(N0, Opnd1); Value *Neg = Builder->CreateFNeg(T); Neg->takeName(&I); - return ReplaceInstUsesWith(I, Neg); + return replaceInstUsesWith(I, Neg); } } @@ -698,7 +712,7 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { Value *R = Builder->CreateFMul(T, Y); R->takeName(&I); - return ReplaceInstUsesWith(I, R); + return replaceInstUsesWith(I, R); } } } @@ -1043,10 +1057,10 @@ Instruction *InstCombiner::visitUDiv(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyUDivInst(Op0, Op1, DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // Handle the integer div common cases if (Instruction *Common = commonIDivTransforms(I)) @@ -1116,27 +1130,43 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifySDivInst(Op0, Op1, DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // Handle the integer div common cases if (Instruction *Common = commonIDivTransforms(I)) return Common; - // sdiv X, -1 == -X - if (match(Op1, m_AllOnes())) - return BinaryOperator::CreateNeg(Op0); + const APInt *Op1C; + if (match(Op1, m_APInt(Op1C))) { + // sdiv X, -1 == -X + if (Op1C->isAllOnesValue()) + return BinaryOperator::CreateNeg(Op0); - if (ConstantInt *RHS = dyn_cast<ConstantInt>(Op1)) { - // sdiv X, C --> ashr exact X, log2(C) - if (I.isExact() && RHS->getValue().isNonNegative() && - RHS->getValue().isPowerOf2()) { - Value *ShAmt = llvm::ConstantInt::get(RHS->getType(), - RHS->getValue().exactLogBase2()); + // sdiv exact X, C --> ashr exact X, log2(C) + if (I.isExact() && Op1C->isNonNegative() && Op1C->isPowerOf2()) { + Value *ShAmt = ConstantInt::get(Op1->getType(), Op1C->exactLogBase2()); return BinaryOperator::CreateExactAShr(Op0, ShAmt, I.getName()); } + + // If the dividend is sign-extended and the constant divisor is small enough + // to fit in the source type, shrink the division to the narrower type: + // (sext X) sdiv C --> sext (X sdiv C) + Value *Op0Src; + if (match(Op0, m_OneUse(m_SExt(m_Value(Op0Src)))) && + Op0Src->getType()->getScalarSizeInBits() >= Op1C->getMinSignedBits()) { + + // In the general case, we need to make sure that the dividend is not the + // minimum signed value because dividing 
that by -1 is UB. But here, we + // know that the -1 divisor case is already handled above. + + Constant *NarrowDivisor = + ConstantExpr::getTrunc(cast<Constant>(Op1), Op0Src->getType()); + Value *NarrowOp = Builder->CreateSDiv(Op0Src, NarrowDivisor); + return new SExtInst(NarrowOp, Op0->getType()); + } } if (Constant *RHS = dyn_cast<Constant>(Op1)) { @@ -1214,11 +1244,11 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyFDivInst(Op0, Op1, I.getFastMathFlags(), DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (isa<Constant>(Op0)) if (SelectInst *SI = dyn_cast<SelectInst>(Op1)) @@ -1363,8 +1393,17 @@ Instruction *InstCombiner::commonIRemTransforms(BinaryOperator &I) { if (Instruction *R = FoldOpIntoSelect(I, SI)) return R; } else if (isa<PHINode>(Op0I)) { - if (Instruction *NV = FoldOpIntoPhi(I)) - return NV; + using namespace llvm::PatternMatch; + const APInt *Op1Int; + if (match(Op1, m_APInt(Op1Int)) && !Op1Int->isMinValue() && + (I.getOpcode() == Instruction::URem || + !Op1Int->isMinSignedValue())) { + // FoldOpIntoPhi will speculate instructions to the end of the PHI's + // predecessor blocks, so do this only if we know the srem or urem + // will not fault. + if (Instruction *NV = FoldOpIntoPhi(I)) + return NV; + } } // See if we can fold away this rem instruction. @@ -1380,10 +1419,10 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyURemInst(Op0, Op1, DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Instruction *common = commonIRemTransforms(I)) return common; @@ -1405,7 +1444,7 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) { if (match(Op0, m_One())) { Value *Cmp = Builder->CreateICmpNE(Op1, Op0); Value *Ext = Builder->CreateZExt(Cmp, I.getType()); - return ReplaceInstUsesWith(I, Ext); + return replaceInstUsesWith(I, Ext); } return nullptr; @@ -1415,10 +1454,10 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifySRemInst(Op0, Op1, DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // Handle the integer rem common cases if (Instruction *Common = commonIRemTransforms(I)) @@ -1490,11 +1529,11 @@ Instruction *InstCombiner::visitFRem(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyFRemInst(Op0, Op1, I.getFastMathFlags(), DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // Handle cases involving: rem X, (select Cond, Y, Z) if (isa<SelectInst>(Op1) && SimplifyDivRemOfSelect(I)) diff --git a/lib/Transforms/InstCombine/InstCombinePHI.cpp b/lib/Transforms/InstCombine/InstCombinePHI.cpp index f1aa98b5e3595..79a4912332ff3 100644 --- a/lib/Transforms/InstCombine/InstCombinePHI.cpp +++ b/lib/Transforms/InstCombine/InstCombinePHI.cpp @@ -15,8 +15,11 @@ #include "llvm/ADT/STLExtras.h" 
#include "llvm/ADT/SmallPtrSet.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/Transforms/Utils/Local.h" using namespace llvm; +using namespace llvm::PatternMatch; #define DEBUG_TYPE "instcombine" @@ -32,15 +35,6 @@ Instruction *InstCombiner::FoldPHIArgBinOpIntoPHI(PHINode &PN) { Type *LHSType = LHSVal->getType(); Type *RHSType = RHSVal->getType(); - bool isNUW = false, isNSW = false, isExact = false; - if (OverflowingBinaryOperator *BO = - dyn_cast<OverflowingBinaryOperator>(FirstInst)) { - isNUW = BO->hasNoUnsignedWrap(); - isNSW = BO->hasNoSignedWrap(); - } else if (PossiblyExactOperator *PEO = - dyn_cast<PossiblyExactOperator>(FirstInst)) - isExact = PEO->isExact(); - // Scan to see if all operands are the same opcode, and all have one use. for (unsigned i = 1; i != PN.getNumIncomingValues(); ++i) { Instruction *I = dyn_cast<Instruction>(PN.getIncomingValue(i)); @@ -56,13 +50,6 @@ Instruction *InstCombiner::FoldPHIArgBinOpIntoPHI(PHINode &PN) { if (CI->getPredicate() != cast<CmpInst>(FirstInst)->getPredicate()) return nullptr; - if (isNUW) - isNUW = cast<OverflowingBinaryOperator>(I)->hasNoUnsignedWrap(); - if (isNSW) - isNSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); - if (isExact) - isExact = cast<PossiblyExactOperator>(I)->isExact(); - // Keep track of which operand needs a phi node. if (I->getOperand(0) != LHSVal) LHSVal = nullptr; if (I->getOperand(1) != RHSVal) RHSVal = nullptr; @@ -121,9 +108,12 @@ Instruction *InstCombiner::FoldPHIArgBinOpIntoPHI(PHINode &PN) { BinaryOperator *BinOp = cast<BinaryOperator>(FirstInst); BinaryOperator *NewBinOp = BinaryOperator::Create(BinOp->getOpcode(), LHSVal, RHSVal); - if (isNUW) NewBinOp->setHasNoUnsignedWrap(); - if (isNSW) NewBinOp->setHasNoSignedWrap(); - if (isExact) NewBinOp->setIsExact(); + + NewBinOp->copyIRFlags(PN.getIncomingValue(0)); + + for (unsigned i = 1, e = PN.getNumIncomingValues(); i != e; ++i) + NewBinOp->andIRFlags(PN.getIncomingValue(i)); + NewBinOp->setDebugLoc(FirstInst->getDebugLoc()); return NewBinOp; } @@ -494,7 +484,6 @@ Instruction *InstCombiner::FoldPHIArgOpIntoPHI(PHINode &PN) { // code size and simplifying code. Constant *ConstantOp = nullptr; Type *CastSrcTy = nullptr; - bool isNUW = false, isNSW = false, isExact = false; if (isa<CastInst>(FirstInst)) { CastSrcTy = FirstInst->getOperand(0)->getType(); @@ -511,14 +500,6 @@ Instruction *InstCombiner::FoldPHIArgOpIntoPHI(PHINode &PN) { ConstantOp = dyn_cast<Constant>(FirstInst->getOperand(1)); if (!ConstantOp) return FoldPHIArgBinOpIntoPHI(PN); - - if (OverflowingBinaryOperator *BO = - dyn_cast<OverflowingBinaryOperator>(FirstInst)) { - isNUW = BO->hasNoUnsignedWrap(); - isNSW = BO->hasNoSignedWrap(); - } else if (PossiblyExactOperator *PEO = - dyn_cast<PossiblyExactOperator>(FirstInst)) - isExact = PEO->isExact(); } else { return nullptr; // Cannot fold this operation. } @@ -534,13 +515,6 @@ Instruction *InstCombiner::FoldPHIArgOpIntoPHI(PHINode &PN) { } else if (I->getOperand(1) != ConstantOp) { return nullptr; } - - if (isNUW) - isNUW = cast<OverflowingBinaryOperator>(I)->hasNoUnsignedWrap(); - if (isNSW) - isNSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); - if (isExact) - isExact = cast<PossiblyExactOperator>(I)->isExact(); } // Okay, they are all the same operation. 
Create a new PHI node of the @@ -581,9 +555,11 @@ Instruction *InstCombiner::FoldPHIArgOpIntoPHI(PHINode &PN) { if (BinaryOperator *BinOp = dyn_cast<BinaryOperator>(FirstInst)) { BinOp = BinaryOperator::Create(BinOp->getOpcode(), PhiVal, ConstantOp); - if (isNUW) BinOp->setHasNoUnsignedWrap(); - if (isNSW) BinOp->setHasNoSignedWrap(); - if (isExact) BinOp->setIsExact(); + BinOp->copyIRFlags(PN.getIncomingValue(0)); + + for (unsigned i = 1, e = PN.getNumIncomingValues(); i != e; ++i) + BinOp->andIRFlags(PN.getIncomingValue(i)); + BinOp->setDebugLoc(FirstInst->getDebugLoc()); return BinOp; } @@ -641,6 +617,16 @@ static bool PHIsEqualValue(PHINode *PN, Value *NonPhiInVal, return true; } +/// Return an existing non-zero constant if this phi node has one, otherwise +/// return constant 1. +static ConstantInt *GetAnyNonZeroConstInt(PHINode &PN) { + assert(isa<IntegerType>(PN.getType()) && "Expect only intger type phi"); + for (Value *V : PN.operands()) + if (auto *ConstVA = dyn_cast<ConstantInt>(V)) + if (!ConstVA->isZeroValue()) + return ConstVA; + return ConstantInt::get(cast<IntegerType>(PN.getType()), 1); +} namespace { struct PHIUsageRecord { @@ -768,7 +754,7 @@ Instruction *InstCombiner::SliceUpIllegalIntegerPHI(PHINode &FirstPhi) { // If we have no users, they must be all self uses, just nuke the PHI. if (PHIUsers.empty()) - return ReplaceInstUsesWith(FirstPhi, UndefValue::get(FirstPhi.getType())); + return replaceInstUsesWith(FirstPhi, UndefValue::get(FirstPhi.getType())); // If this phi node is transformable, create new PHIs for all the pieces // extracted out of it. First, sort the users by their offset and size. @@ -864,22 +850,22 @@ Instruction *InstCombiner::SliceUpIllegalIntegerPHI(PHINode &FirstPhi) { } // Replace the use of this piece with the PHI node. - ReplaceInstUsesWith(*PHIUsers[UserI].Inst, EltPHI); + replaceInstUsesWith(*PHIUsers[UserI].Inst, EltPHI); } // Replace all the remaining uses of the PHI nodes (self uses and the lshrs) // with undefs. Value *Undef = UndefValue::get(FirstPhi.getType()); for (unsigned i = 1, e = PHIsToSlice.size(); i != e; ++i) - ReplaceInstUsesWith(*PHIsToSlice[i], Undef); - return ReplaceInstUsesWith(FirstPhi, Undef); + replaceInstUsesWith(*PHIsToSlice[i], Undef); + return replaceInstUsesWith(FirstPhi, Undef); } // PHINode simplification // Instruction *InstCombiner::visitPHINode(PHINode &PN) { if (Value *V = SimplifyInstruction(&PN, DL, TLI, DT, AC)) - return ReplaceInstUsesWith(PN, V); + return replaceInstUsesWith(PN, V); if (Instruction *Result = FoldPHIArgZextsIntoPHI(PN)) return Result; @@ -905,7 +891,7 @@ Instruction *InstCombiner::visitPHINode(PHINode &PN) { SmallPtrSet<PHINode*, 16> PotentiallyDeadPHIs; PotentiallyDeadPHIs.insert(&PN); if (DeadPHICycle(PU, PotentiallyDeadPHIs)) - return ReplaceInstUsesWith(PN, UndefValue::get(PN.getType())); + return replaceInstUsesWith(PN, UndefValue::get(PN.getType())); } // If this phi has a single use, and if that use just computes a value for @@ -917,7 +903,30 @@ Instruction *InstCombiner::visitPHINode(PHINode &PN) { if (PHIUser->hasOneUse() && (isa<BinaryOperator>(PHIUser) || isa<GetElementPtrInst>(PHIUser)) && PHIUser->user_back() == &PN) { - return ReplaceInstUsesWith(PN, UndefValue::get(PN.getType())); + return replaceInstUsesWith(PN, UndefValue::get(PN.getType())); + } + // When a PHI is used only to be compared with zero, it is safe to replace + // an incoming value proved as known nonzero with any non-zero constant. 
+ // For example, in the code below, the incoming value %v can be replaced + // with any non-zero constant based on the fact that the PHI is only used to + // be compared with zero and %v is a known non-zero value: + // %v = select %cond, 1, 2 + // %p = phi [%v, BB] ... + // icmp eq, %p, 0 + auto *CmpInst = dyn_cast<ICmpInst>(PHIUser); + // FIXME: To be simple, handle only integer type for now. + if (CmpInst && isa<IntegerType>(PN.getType()) && CmpInst->isEquality() && + match(CmpInst->getOperand(1), m_Zero())) { + ConstantInt *NonZeroConst = nullptr; + for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i) { + Instruction *CtxI = PN.getIncomingBlock(i)->getTerminator(); + Value *VA = PN.getIncomingValue(i); + if (isKnownNonZero(VA, DL, 0, AC, CtxI, DT)) { + if (!NonZeroConst) + NonZeroConst = GetAnyNonZeroConstInt(PN); + PN.setIncomingValue(i, NonZeroConst); + } + } } } @@ -951,7 +960,7 @@ Instruction *InstCombiner::visitPHINode(PHINode &PN) { if (InValNo == NumIncomingVals) { SmallPtrSet<PHINode*, 16> ValueEqualPHIs; if (PHIsEqualValue(&PN, NonPhiInVal, ValueEqualPHIs)) - return ReplaceInstUsesWith(PN, NonPhiInVal); + return replaceInstUsesWith(PN, NonPhiInVal); } } } diff --git a/lib/Transforms/InstCombine/InstCombineSelect.cpp b/lib/Transforms/InstCombine/InstCombineSelect.cpp index 51219bcb0b7ba..d7eed790e2ab2 100644 --- a/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -116,25 +116,41 @@ static Constant *GetSelectFoldableConstant(Instruction *I) { } } -/// Here we have (select c, TI, FI), and we know that TI and FI -/// have the same opcode and only one use each. Try to simplify this. +/// We have (select c, TI, FI), and we know that TI and FI have the same opcode. Instruction *InstCombiner::FoldSelectOpOp(SelectInst &SI, Instruction *TI, Instruction *FI) { - if (TI->getNumOperands() == 1) { - // If this is a non-volatile load or a cast from the same type, - // merge. - if (TI->isCast()) { - Type *FIOpndTy = FI->getOperand(0)->getType(); - if (TI->getOperand(0)->getType() != FIOpndTy) + // If this is a cast from the same type, merge. + if (TI->getNumOperands() == 1 && TI->isCast()) { + Type *FIOpndTy = FI->getOperand(0)->getType(); + if (TI->getOperand(0)->getType() != FIOpndTy) + return nullptr; + + // The select condition may be a vector. We may only change the operand + // type if the vector width remains the same (and matches the condition). + Type *CondTy = SI.getCondition()->getType(); + if (CondTy->isVectorTy()) { + if (!FIOpndTy->isVectorTy()) return nullptr; - // The select condition may be a vector. We may only change the operand - // type if the vector width remains the same (and matches the condition). - Type *CondTy = SI.getCondition()->getType(); - if (CondTy->isVectorTy() && (!FIOpndTy->isVectorTy() || - CondTy->getVectorNumElements() != FIOpndTy->getVectorNumElements())) + if (CondTy->getVectorNumElements() != FIOpndTy->getVectorNumElements()) return nullptr; - } else { - return nullptr; // unknown unary op. + + // TODO: If the backend knew how to deal with casts better, we could + // remove this limitation. For now, there's too much potential to create + // worse codegen by promoting the select ahead of size-altering casts + // (PR28160). + // + // Note that ValueTracking's matchSelectPattern() looks through casts + // without checking 'hasOneUse' when it matches min/max patterns, so this + // transform may end up happening anyway. 
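The surviving part of the fold is: select c, (cast X), (cast Y) --> cast (select c, X, Y), when both arms are the same kind of cast from the same source type; the hasOneUse/BitCast guards above only limit when the fold fires, not its correctness. A small standalone check in plain C++ (not LLVM code) that the two forms agree for a sign extend:

    #include <cassert>
    #include <cstdint>

    // "Before" and "after" of the fold for a sign extend; this standalone
    // check only shows that the two forms agree.
    static int64_t selectOfSexts(bool C, int32_t X, int32_t Y) {
      return C ? (int64_t)X : (int64_t)Y; // select c, (sext X), (sext Y)
    }

    static int64_t sextOfSelect(bool C, int32_t X, int32_t Y) {
      return (int64_t)(C ? X : Y);        // sext (select c, X, Y)
    }

    int main() {
      assert(selectOfSexts(true, -7, 100000) == sextOfSelect(true, -7, 100000));
      assert(selectOfSexts(false, -7, 100000) == sextOfSelect(false, -7, 100000));
      assert(selectOfSexts(false, INT32_MIN, INT32_MAX) ==
             sextOfSelect(false, INT32_MIN, INT32_MAX));
      return 0;
    }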
+ if (TI->getOpcode() != Instruction::BitCast && + (!TI->hasOneUse() || !FI->hasOneUse())) + return nullptr; + + } else if (!TI->hasOneUse() || !FI->hasOneUse()) { + // TODO: The one-use restrictions for a scalar select could be eased if + // the fold of a select in visitLoadInst() was enhanced to match a pattern + // that includes a cast. + return nullptr; } // Fold this by inserting a select from the input values. @@ -144,8 +160,13 @@ Instruction *InstCombiner::FoldSelectOpOp(SelectInst &SI, Instruction *TI, TI->getType()); } - // Only handle binary operators here. - if (!isa<BinaryOperator>(TI)) + // TODO: This function ends awkwardly in unreachable - fix to be more normal. + + // Only handle binary operators with one-use here. As with the cast case + // above, it may be possible to relax the one-use constraint, but that needs + // be examined carefully since it may not reduce the total number of + // instructions. + if (!isa<BinaryOperator>(TI) || !TI->hasOneUse() || !FI->hasOneUse()) return nullptr; // Figure out if the operations have any operands in common. @@ -231,12 +252,7 @@ Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal, BinaryOperator *TVI_BO = cast<BinaryOperator>(TVI); BinaryOperator *BO = BinaryOperator::Create(TVI_BO->getOpcode(), FalseVal, NewSel); - if (isa<PossiblyExactOperator>(BO)) - BO->setIsExact(TVI_BO->isExact()); - if (isa<OverflowingBinaryOperator>(BO)) { - BO->setHasNoUnsignedWrap(TVI_BO->hasNoUnsignedWrap()); - BO->setHasNoSignedWrap(TVI_BO->hasNoSignedWrap()); - } + BO->copyIRFlags(TVI_BO); return BO; } } @@ -266,12 +282,7 @@ Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal, BinaryOperator *FVI_BO = cast<BinaryOperator>(FVI); BinaryOperator *BO = BinaryOperator::Create(FVI_BO->getOpcode(), TrueVal, NewSel); - if (isa<PossiblyExactOperator>(BO)) - BO->setIsExact(FVI_BO->isExact()); - if (isa<OverflowingBinaryOperator>(BO)) { - BO->setHasNoUnsignedWrap(FVI_BO->hasNoUnsignedWrap()); - BO->setHasNoSignedWrap(FVI_BO->hasNoSignedWrap()); - } + BO->copyIRFlags(FVI_BO); return BO; } } @@ -353,7 +364,7 @@ static Value *foldSelectICmpAndOr(const SelectInst &SI, Value *TrueVal, /// %1 = icmp ne i32 %x, 0 /// %2 = select i1 %1, i32 %0, i32 32 /// \code -/// +/// /// into: /// %0 = tail call i32 @llvm.cttz.i32(i32 %x, i1 false) static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal, @@ -519,10 +530,10 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI, // Check if we can express the operation with a single or. if (C2->isAllOnesValue()) - return ReplaceInstUsesWith(SI, Builder->CreateOr(AShr, C1)); + return replaceInstUsesWith(SI, Builder->CreateOr(AShr, C1)); Value *And = Builder->CreateAnd(AShr, C2->getValue()-C1->getValue()); - return ReplaceInstUsesWith(SI, Builder->CreateAdd(And, C1)); + return replaceInstUsesWith(SI, Builder->CreateAdd(And, C1)); } } } @@ -585,15 +596,15 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI, V = Builder->CreateOr(X, *Y); if (V) - return ReplaceInstUsesWith(SI, V); + return replaceInstUsesWith(SI, V); } } if (Value *V = foldSelectICmpAndOr(SI, TrueVal, FalseVal, Builder)) - return ReplaceInstUsesWith(SI, V); + return replaceInstUsesWith(SI, V); if (Value *V = foldSelectCttzCtlz(ICI, TrueVal, FalseVal, Builder)) - return ReplaceInstUsesWith(SI, V); + return replaceInstUsesWith(SI, V); return Changed ? 
&SI : nullptr; } @@ -642,11 +653,14 @@ Instruction *InstCombiner::FoldSPFofSPF(Instruction *Inner, Value *A, Value *B, Instruction &Outer, SelectPatternFlavor SPF2, Value *C) { + if (Outer.getType() != Inner->getType()) + return nullptr; + if (C == A || C == B) { // MAX(MAX(A, B), B) -> MAX(A, B) // MIN(MIN(a, b), a) -> MIN(a, b) if (SPF1 == SPF2) - return ReplaceInstUsesWith(Outer, Inner); + return replaceInstUsesWith(Outer, Inner); // MAX(MIN(a, b), a) -> a // MIN(MAX(a, b), a) -> a @@ -654,14 +668,14 @@ Instruction *InstCombiner::FoldSPFofSPF(Instruction *Inner, (SPF1 == SPF_SMAX && SPF2 == SPF_SMIN) || (SPF1 == SPF_UMIN && SPF2 == SPF_UMAX) || (SPF1 == SPF_UMAX && SPF2 == SPF_UMIN)) - return ReplaceInstUsesWith(Outer, C); + return replaceInstUsesWith(Outer, C); } if (SPF1 == SPF2) { if (ConstantInt *CB = dyn_cast<ConstantInt>(B)) { if (ConstantInt *CC = dyn_cast<ConstantInt>(C)) { - APInt ACB = CB->getValue(); - APInt ACC = CC->getValue(); + const APInt &ACB = CB->getValue(); + const APInt &ACC = CC->getValue(); // MIN(MIN(A, 23), 97) -> MIN(A, 23) // MAX(MAX(A, 97), 23) -> MAX(A, 97) @@ -669,7 +683,7 @@ Instruction *InstCombiner::FoldSPFofSPF(Instruction *Inner, (SPF1 == SPF_SMIN && ACB.sle(ACC)) || (SPF1 == SPF_UMAX && ACB.uge(ACC)) || (SPF1 == SPF_SMAX && ACB.sge(ACC))) - return ReplaceInstUsesWith(Outer, Inner); + return replaceInstUsesWith(Outer, Inner); // MIN(MIN(A, 97), 23) -> MIN(A, 23) // MAX(MAX(A, 23), 97) -> MAX(A, 97) @@ -687,7 +701,7 @@ Instruction *InstCombiner::FoldSPFofSPF(Instruction *Inner, // ABS(ABS(X)) -> ABS(X) // NABS(NABS(X)) -> NABS(X) if (SPF1 == SPF2 && (SPF1 == SPF_ABS || SPF1 == SPF_NABS)) { - return ReplaceInstUsesWith(Outer, Inner); + return replaceInstUsesWith(Outer, Inner); } // ABS(NABS(X)) -> ABS(X) @@ -697,7 +711,7 @@ Instruction *InstCombiner::FoldSPFofSPF(Instruction *Inner, SelectInst *SI = cast<SelectInst>(Inner); Value *NewSI = Builder->CreateSelect( SI->getCondition(), SI->getFalseValue(), SI->getTrueValue()); - return ReplaceInstUsesWith(Outer, NewSI); + return replaceInstUsesWith(Outer, NewSI); } auto IsFreeOrProfitableToInvert = @@ -742,7 +756,7 @@ Instruction *InstCombiner::FoldSPFofSPF(Instruction *Inner, Builder, getInverseMinMaxSelectPattern(SPF1), NotA, NotB); Value *NewOuter = Builder->CreateNot(generateMinMaxSelectPattern( Builder, getInverseMinMaxSelectPattern(SPF2), NewInner, NotC)); - return ReplaceInstUsesWith(Outer, NewOuter); + return replaceInstUsesWith(Outer, NewOuter); } return nullptr; @@ -823,76 +837,156 @@ static Value *foldSelectICmpAnd(const SelectInst &SI, ConstantInt *TrueVal, return V; } +/// Turn select C, (X + Y), (X - Y) --> (X + (select C, Y, (-Y))). +/// This is even legal for FP. 
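A standalone sketch of the equivalence the new helper exploits, written with unsigned C++ arithmetic so the example itself has no overflow UB (an illustration, not LLVM code); the same reasoning carries over to FP, where x - y and x + (-y) produce identical results:

    #include <cassert>
    #include <cstdint>

    // select C, (X + Y), (X - Y) --> X + (select C, Y, -Y)
    static uint32_t before(bool C, uint32_t X, uint32_t Y) {
      return C ? X + Y : X - Y;
    }

    static uint32_t after(bool C, uint32_t X, uint32_t Y) {
      uint32_t SelY = C ? Y : 0u - Y; // select C, Y, (-Y)
      return X + SelY;
    }

    int main() {
      assert(before(true, 10, 3) == after(true, 10, 3));   // 13
      assert(before(false, 10, 3) == after(false, 10, 3)); // 7
      assert(before(false, 0, 7) == after(false, 0, 7));   // wraps identically
      return 0;
    }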
+static Instruction *foldAddSubSelect(SelectInst &SI, + InstCombiner::BuilderTy &Builder) { + Value *CondVal = SI.getCondition(); + Value *TrueVal = SI.getTrueValue(); + Value *FalseVal = SI.getFalseValue(); + auto *TI = dyn_cast<Instruction>(TrueVal); + auto *FI = dyn_cast<Instruction>(FalseVal); + if (!TI || !FI || !TI->hasOneUse() || !FI->hasOneUse()) + return nullptr; + + Instruction *AddOp = nullptr, *SubOp = nullptr; + if ((TI->getOpcode() == Instruction::Sub && + FI->getOpcode() == Instruction::Add) || + (TI->getOpcode() == Instruction::FSub && + FI->getOpcode() == Instruction::FAdd)) { + AddOp = FI; + SubOp = TI; + } else if ((FI->getOpcode() == Instruction::Sub && + TI->getOpcode() == Instruction::Add) || + (FI->getOpcode() == Instruction::FSub && + TI->getOpcode() == Instruction::FAdd)) { + AddOp = TI; + SubOp = FI; + } + + if (AddOp) { + Value *OtherAddOp = nullptr; + if (SubOp->getOperand(0) == AddOp->getOperand(0)) { + OtherAddOp = AddOp->getOperand(1); + } else if (SubOp->getOperand(0) == AddOp->getOperand(1)) { + OtherAddOp = AddOp->getOperand(0); + } + + if (OtherAddOp) { + // So at this point we know we have (Y -> OtherAddOp): + // select C, (add X, Y), (sub X, Z) + Value *NegVal; // Compute -Z + if (SI.getType()->isFPOrFPVectorTy()) { + NegVal = Builder.CreateFNeg(SubOp->getOperand(1)); + if (Instruction *NegInst = dyn_cast<Instruction>(NegVal)) { + FastMathFlags Flags = AddOp->getFastMathFlags(); + Flags &= SubOp->getFastMathFlags(); + NegInst->setFastMathFlags(Flags); + } + } else { + NegVal = Builder.CreateNeg(SubOp->getOperand(1)); + } + + Value *NewTrueOp = OtherAddOp; + Value *NewFalseOp = NegVal; + if (AddOp != TI) + std::swap(NewTrueOp, NewFalseOp); + Value *NewSel = Builder.CreateSelect(CondVal, NewTrueOp, NewFalseOp, + SI.getName() + ".p"); + + if (SI.getType()->isFPOrFPVectorTy()) { + Instruction *RI = + BinaryOperator::CreateFAdd(SubOp->getOperand(0), NewSel); + + FastMathFlags Flags = AddOp->getFastMathFlags(); + Flags &= SubOp->getFastMathFlags(); + RI->setFastMathFlags(Flags); + return RI; + } else + return BinaryOperator::CreateAdd(SubOp->getOperand(0), NewSel); + } + } + return nullptr; +} + Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Value *CondVal = SI.getCondition(); Value *TrueVal = SI.getTrueValue(); Value *FalseVal = SI.getFalseValue(); + Type *SelType = SI.getType(); if (Value *V = SimplifySelectInst(CondVal, TrueVal, FalseVal, DL, TLI, DT, AC)) - return ReplaceInstUsesWith(SI, V); + return replaceInstUsesWith(SI, V); - if (SI.getType()->isIntegerTy(1)) { - if (ConstantInt *C = dyn_cast<ConstantInt>(TrueVal)) { - if (C->getZExtValue()) { - // Change: A = select B, true, C --> A = or B, C - return BinaryOperator::CreateOr(CondVal, FalseVal); - } + if (SelType->getScalarType()->isIntegerTy(1) && + TrueVal->getType() == CondVal->getType()) { + if (match(TrueVal, m_One())) { + // Change: A = select B, true, C --> A = or B, C + return BinaryOperator::CreateOr(CondVal, FalseVal); + } + if (match(TrueVal, m_Zero())) { // Change: A = select B, false, C --> A = and !B, C - Value *NotCond = Builder->CreateNot(CondVal, "not."+CondVal->getName()); + Value *NotCond = Builder->CreateNot(CondVal, "not." 
+ CondVal->getName()); return BinaryOperator::CreateAnd(NotCond, FalseVal); } - if (ConstantInt *C = dyn_cast<ConstantInt>(FalseVal)) { - if (!C->getZExtValue()) { - // Change: A = select B, C, false --> A = and B, C - return BinaryOperator::CreateAnd(CondVal, TrueVal); - } + if (match(FalseVal, m_Zero())) { + // Change: A = select B, C, false --> A = and B, C + return BinaryOperator::CreateAnd(CondVal, TrueVal); + } + if (match(FalseVal, m_One())) { // Change: A = select B, C, true --> A = or !B, C - Value *NotCond = Builder->CreateNot(CondVal, "not."+CondVal->getName()); + Value *NotCond = Builder->CreateNot(CondVal, "not." + CondVal->getName()); return BinaryOperator::CreateOr(NotCond, TrueVal); } - // select a, b, a -> a&b - // select a, a, b -> a|b + // select a, a, b -> a | b + // select a, b, a -> a & b if (CondVal == TrueVal) return BinaryOperator::CreateOr(CondVal, FalseVal); if (CondVal == FalseVal) return BinaryOperator::CreateAnd(CondVal, TrueVal); - // select a, ~a, b -> (~a)&b - // select a, b, ~a -> (~a)|b + // select a, ~a, b -> (~a) & b + // select a, b, ~a -> (~a) | b if (match(TrueVal, m_Not(m_Specific(CondVal)))) return BinaryOperator::CreateAnd(TrueVal, FalseVal); if (match(FalseVal, m_Not(m_Specific(CondVal)))) return BinaryOperator::CreateOr(TrueVal, FalseVal); } - // Selecting between two integer constants? - if (ConstantInt *TrueValC = dyn_cast<ConstantInt>(TrueVal)) - if (ConstantInt *FalseValC = dyn_cast<ConstantInt>(FalseVal)) { - // select C, 1, 0 -> zext C to int - if (FalseValC->isZero() && TrueValC->getValue() == 1) - return new ZExtInst(CondVal, SI.getType()); - - // select C, -1, 0 -> sext C to int - if (FalseValC->isZero() && TrueValC->isAllOnesValue()) - return new SExtInst(CondVal, SI.getType()); - - // select C, 0, 1 -> zext !C to int - if (TrueValC->isZero() && FalseValC->getValue() == 1) { - Value *NotCond = Builder->CreateNot(CondVal, "not."+CondVal->getName()); - return new ZExtInst(NotCond, SI.getType()); - } + // Selecting between two integer or vector splat integer constants? + // + // Note that we don't handle a scalar select of vectors: + // select i1 %c, <2 x i8> <1, 1>, <2 x i8> <0, 0> + // because that may need 3 instructions to splat the condition value: + // extend, insertelement, shufflevector. + if (CondVal->getType()->isVectorTy() == SelType->isVectorTy()) { + // select C, 1, 0 -> zext C to int + if (match(TrueVal, m_One()) && match(FalseVal, m_Zero())) + return new ZExtInst(CondVal, SelType); + + // select C, -1, 0 -> sext C to int + if (match(TrueVal, m_AllOnes()) && match(FalseVal, m_Zero())) + return new SExtInst(CondVal, SelType); + + // select C, 0, 1 -> zext !C to int + if (match(TrueVal, m_Zero()) && match(FalseVal, m_One())) { + Value *NotCond = Builder->CreateNot(CondVal, "not." + CondVal->getName()); + return new ZExtInst(NotCond, SelType); + } - // select C, 0, -1 -> sext !C to int - if (TrueValC->isZero() && FalseValC->isAllOnesValue()) { - Value *NotCond = Builder->CreateNot(CondVal, "not."+CondVal->getName()); - return new SExtInst(NotCond, SI.getType()); - } + // select C, 0, -1 -> sext !C to int + if (match(TrueVal, m_Zero()) && match(FalseVal, m_AllOnes())) { + Value *NotCond = Builder->CreateNot(CondVal, "not." 
+ CondVal->getName()); + return new SExtInst(NotCond, SelType); + } + } + if (ConstantInt *TrueValC = dyn_cast<ConstantInt>(TrueVal)) + if (ConstantInt *FalseValC = dyn_cast<ConstantInt>(FalseVal)) if (Value *V = foldSelectICmpAnd(SI, TrueValC, FalseValC, Builder)) - return ReplaceInstUsesWith(SI, V); - } + return replaceInstUsesWith(SI, V); // See if we are selecting two values based on a comparison of the two values. if (FCmpInst *FCI = dyn_cast<FCmpInst>(CondVal)) { @@ -907,7 +1001,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { !CFPt->getValueAPF().isZero()) || ((CFPf = dyn_cast<ConstantFP>(FalseVal)) && !CFPf->getValueAPF().isZero())) - return ReplaceInstUsesWith(SI, FalseVal); + return replaceInstUsesWith(SI, FalseVal); } // Transform (X une Y) ? X : Y -> X if (FCI->getPredicate() == FCmpInst::FCMP_UNE) { @@ -919,7 +1013,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { !CFPt->getValueAPF().isZero()) || ((CFPf = dyn_cast<ConstantFP>(FalseVal)) && !CFPf->getValueAPF().isZero())) - return ReplaceInstUsesWith(SI, TrueVal); + return replaceInstUsesWith(SI, TrueVal); } // Canonicalize to use ordered comparisons by swapping the select @@ -950,7 +1044,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { !CFPt->getValueAPF().isZero()) || ((CFPf = dyn_cast<ConstantFP>(FalseVal)) && !CFPf->getValueAPF().isZero())) - return ReplaceInstUsesWith(SI, FalseVal); + return replaceInstUsesWith(SI, FalseVal); } // Transform (X une Y) ? Y : X -> Y if (FCI->getPredicate() == FCmpInst::FCMP_UNE) { @@ -962,7 +1056,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { !CFPt->getValueAPF().isZero()) || ((CFPf = dyn_cast<ConstantFP>(FalseVal)) && !CFPf->getValueAPF().isZero())) - return ReplaceInstUsesWith(SI, TrueVal); + return replaceInstUsesWith(SI, TrueVal); } // Canonicalize to use ordered comparisons by swapping the select @@ -991,77 +1085,18 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { if (Instruction *Result = visitSelectInstWithICmp(SI, ICI)) return Result; - if (Instruction *TI = dyn_cast<Instruction>(TrueVal)) - if (Instruction *FI = dyn_cast<Instruction>(FalseVal)) - if (TI->hasOneUse() && FI->hasOneUse()) { - Instruction *AddOp = nullptr, *SubOp = nullptr; - - // Turn (select C, (op X, Y), (op X, Z)) -> (op X, (select C, Y, Z)) - if (TI->getOpcode() == FI->getOpcode()) - if (Instruction *IV = FoldSelectOpOp(SI, TI, FI)) - return IV; - - // Turn select C, (X+Y), (X-Y) --> (X+(select C, Y, (-Y))). This is - // even legal for FP. 
- if ((TI->getOpcode() == Instruction::Sub && - FI->getOpcode() == Instruction::Add) || - (TI->getOpcode() == Instruction::FSub && - FI->getOpcode() == Instruction::FAdd)) { - AddOp = FI; SubOp = TI; - } else if ((FI->getOpcode() == Instruction::Sub && - TI->getOpcode() == Instruction::Add) || - (FI->getOpcode() == Instruction::FSub && - TI->getOpcode() == Instruction::FAdd)) { - AddOp = TI; SubOp = FI; - } + if (Instruction *Add = foldAddSubSelect(SI, *Builder)) + return Add; - if (AddOp) { - Value *OtherAddOp = nullptr; - if (SubOp->getOperand(0) == AddOp->getOperand(0)) { - OtherAddOp = AddOp->getOperand(1); - } else if (SubOp->getOperand(0) == AddOp->getOperand(1)) { - OtherAddOp = AddOp->getOperand(0); - } - - if (OtherAddOp) { - // So at this point we know we have (Y -> OtherAddOp): - // select C, (add X, Y), (sub X, Z) - Value *NegVal; // Compute -Z - if (SI.getType()->isFPOrFPVectorTy()) { - NegVal = Builder->CreateFNeg(SubOp->getOperand(1)); - if (Instruction *NegInst = dyn_cast<Instruction>(NegVal)) { - FastMathFlags Flags = AddOp->getFastMathFlags(); - Flags &= SubOp->getFastMathFlags(); - NegInst->setFastMathFlags(Flags); - } - } else { - NegVal = Builder->CreateNeg(SubOp->getOperand(1)); - } - - Value *NewTrueOp = OtherAddOp; - Value *NewFalseOp = NegVal; - if (AddOp != TI) - std::swap(NewTrueOp, NewFalseOp); - Value *NewSel = - Builder->CreateSelect(CondVal, NewTrueOp, - NewFalseOp, SI.getName() + ".p"); - - if (SI.getType()->isFPOrFPVectorTy()) { - Instruction *RI = - BinaryOperator::CreateFAdd(SubOp->getOperand(0), NewSel); - - FastMathFlags Flags = AddOp->getFastMathFlags(); - Flags &= SubOp->getFastMathFlags(); - RI->setFastMathFlags(Flags); - return RI; - } else - return BinaryOperator::CreateAdd(SubOp->getOperand(0), NewSel); - } - } - } + // Turn (select C, (op X, Y), (op X, Z)) -> (op X, (select C, Y, Z)) + auto *TI = dyn_cast<Instruction>(TrueVal); + auto *FI = dyn_cast<Instruction>(FalseVal); + if (TI && FI && TI->getOpcode() == FI->getOpcode()) + if (Instruction *IV = FoldSelectOpOp(SI, TI, FI)) + return IV; // See if we can fold the select into one of our operands. - if (SI.getType()->isIntOrIntVectorTy() || SI.getType()->isFPOrFPVectorTy()) { + if (SelType->isIntOrIntVectorTy() || SelType->isFPOrFPVectorTy()) { if (Instruction *FoldI = FoldSelectIntoOp(SI, TrueVal, FalseVal)) return FoldI; @@ -1073,7 +1108,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { if (SelectPatternResult::isMinOrMax(SPF)) { // Canonicalize so that type casts are outside select patterns. 
if (LHS->getType()->getPrimitiveSizeInBits() != - SI.getType()->getPrimitiveSizeInBits()) { + SelType->getPrimitiveSizeInBits()) { CmpInst::Predicate Pred = getCmpPredicateForMinMax(SPF, SPR.Ordered); Value *Cmp; @@ -1088,8 +1123,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Value *NewSI = Builder->CreateCast(CastOp, Builder->CreateSelect(Cmp, LHS, RHS), - SI.getType()); - return ReplaceInstUsesWith(SI, NewSI); + SelType); + return replaceInstUsesWith(SI, NewSI); } } @@ -1132,7 +1167,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { : Builder->CreateICmpULT(NewLHS, NewRHS); Value *NewSI = Builder->CreateNot(Builder->CreateSelect(NewCmp, NewLHS, NewRHS)); - return ReplaceInstUsesWith(SI, NewSI); + return replaceInstUsesWith(SI, NewSI); } } } @@ -1195,18 +1230,36 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { return &SI; } - if (VectorType* VecTy = dyn_cast<VectorType>(SI.getType())) { + if (VectorType* VecTy = dyn_cast<VectorType>(SelType)) { unsigned VWidth = VecTy->getNumElements(); APInt UndefElts(VWidth, 0); APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth)); if (Value *V = SimplifyDemandedVectorElts(&SI, AllOnesEltMask, UndefElts)) { if (V != &SI) - return ReplaceInstUsesWith(SI, V); + return replaceInstUsesWith(SI, V); return &SI; } if (isa<ConstantAggregateZero>(CondVal)) { - return ReplaceInstUsesWith(SI, FalseVal); + return replaceInstUsesWith(SI, FalseVal); + } + } + + // See if we can determine the result of this select based on a dominating + // condition. + BasicBlock *Parent = SI.getParent(); + if (BasicBlock *Dom = Parent->getSinglePredecessor()) { + auto *PBI = dyn_cast_or_null<BranchInst>(Dom->getTerminator()); + if (PBI && PBI->isConditional() && + PBI->getSuccessor(0) != PBI->getSuccessor(1) && + (PBI->getSuccessor(0) == Parent || PBI->getSuccessor(1) == Parent)) { + bool CondIsFalse = PBI->getSuccessor(1) == Parent; + Optional<bool> Implication = isImpliedCondition( + PBI->getCondition(), SI.getCondition(), DL, CondIsFalse); + if (Implication) { + Value *V = *Implication ? TrueVal : FalseVal; + return replaceInstUsesWith(SI, V); + } } } diff --git a/lib/Transforms/InstCombine/InstCombineShifts.cpp b/lib/Transforms/InstCombine/InstCombineShifts.cpp index 0c7defa5fff83..08e16a7ee1af4 100644 --- a/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -55,6 +55,51 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { return nullptr; } +/// Return true if we can simplify two logical (either left or right) shifts +/// that have constant shift amounts. +static bool canEvaluateShiftedShift(unsigned FirstShiftAmt, + bool IsFirstShiftLeft, + Instruction *SecondShift, InstCombiner &IC, + Instruction *CxtI) { + assert(SecondShift->isLogicalShift() && "Unexpected instruction type"); + + // We need constant shifts. + auto *SecondShiftConst = dyn_cast<ConstantInt>(SecondShift->getOperand(1)); + if (!SecondShiftConst) + return false; + + unsigned SecondShiftAmt = SecondShiftConst->getZExtValue(); + bool IsSecondShiftLeft = SecondShift->getOpcode() == Instruction::Shl; + + // We can always fold shl(c1) + shl(c2) -> shl(c1+c2). + // We can always fold lshr(c1) + lshr(c2) -> lshr(c1+c2). + if (IsFirstShiftLeft == IsSecondShiftLeft) + return true; + + // We can always fold lshr(c) + shl(c) -> and(c2). + // We can always fold shl(c) + lshr(c) -> and(c2). 
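The comments above rely on standard identities for pairs of constant logical shifts: same-direction amounts simply add, and an opposing pair with equal amounts is a mask. A minimal standalone check in plain C++ (not LLVM code):

    #include <cassert>
    #include <cstdint>

    int main() {
      const uint32_t X = 0xDEADBEEFu;
      const unsigned C = 5;

      // (X >>u C) << C  ==  X & (-1 << C)
      assert(((X >> C) << C) == (X & (0xFFFFFFFFu << C)));

      // (X << C) >>u C  ==  X & (-1 >>u C)
      assert(((X << C) >> C) == (X & (0xFFFFFFFFu >> C)));

      // Same-direction shifts just add their amounts.
      assert(((X >> 3) >> 4) == (X >> 7));
      return 0;
    }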
+ if (FirstShiftAmt == SecondShiftAmt) + return true; + + unsigned TypeWidth = SecondShift->getType()->getScalarSizeInBits(); + + // If the 2nd shift is bigger than the 1st, we can fold: + // lshr(c1) + shl(c2) -> shl(c3) + and(c4) or + // shl(c1) + lshr(c2) -> lshr(c3) + and(c4), + // but it isn't profitable unless we know the and'd out bits are already zero. + // Also check that the 2nd shift is valid (less than the type width) or we'll + // crash trying to produce the bit mask for the 'and'. + if (SecondShiftAmt > FirstShiftAmt && SecondShiftAmt < TypeWidth) { + unsigned MaskShift = IsSecondShiftLeft ? TypeWidth - SecondShiftAmt + : SecondShiftAmt - FirstShiftAmt; + APInt Mask = APInt::getLowBitsSet(TypeWidth, FirstShiftAmt) << MaskShift; + if (IC.MaskedValueIsZero(SecondShift->getOperand(0), Mask, 0, CxtI)) + return true; + } + + return false; +} + /// See if we can compute the specified value, but shifted /// logically to the left or right by some number of bits. This should return /// true if the expression can be computed for the same cost as the current @@ -67,7 +112,7 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { /// where the client will ask if E can be computed shifted right by 64-bits. If /// this succeeds, the GetShiftedValue function will be called to produce the /// value. -static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift, +static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift, InstCombiner &IC, Instruction *CxtI) { // We can always evaluate constants shifted. if (isa<Constant>(V)) @@ -81,8 +126,8 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift, // the value which means that we don't care if the shift has multiple uses. // TODO: Handle opposite shift by exact value. ConstantInt *CI = nullptr; - if ((isLeftShift && match(I, m_LShr(m_Value(), m_ConstantInt(CI)))) || - (!isLeftShift && match(I, m_Shl(m_Value(), m_ConstantInt(CI))))) { + if ((IsLeftShift && match(I, m_LShr(m_Value(), m_ConstantInt(CI)))) || + (!IsLeftShift && match(I, m_Shl(m_Value(), m_ConstantInt(CI))))) { if (CI->getZExtValue() == NumBits) { // TODO: Check that the input bits are already zero with MaskedValueIsZero #if 0 @@ -111,64 +156,19 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift, case Instruction::Or: case Instruction::Xor: // Bitwise operators can all arbitrarily be arbitrarily evaluated shifted. - return CanEvaluateShifted(I->getOperand(0), NumBits, isLeftShift, IC, I) && - CanEvaluateShifted(I->getOperand(1), NumBits, isLeftShift, IC, I); - - case Instruction::Shl: { - // We can often fold the shift into shifts-by-a-constant. - CI = dyn_cast<ConstantInt>(I->getOperand(1)); - if (!CI) return false; - - // We can always fold shl(c1)+shl(c2) -> shl(c1+c2). - if (isLeftShift) return true; + return CanEvaluateShifted(I->getOperand(0), NumBits, IsLeftShift, IC, I) && + CanEvaluateShifted(I->getOperand(1), NumBits, IsLeftShift, IC, I); - // We can always turn shl(c)+shr(c) -> and(c2). - if (CI->getValue() == NumBits) return true; + case Instruction::Shl: + case Instruction::LShr: + return canEvaluateShiftedShift(NumBits, IsLeftShift, I, IC, CxtI); - unsigned TypeWidth = I->getType()->getScalarSizeInBits(); - - // We can turn shl(c1)+shr(c2) -> shl(c3)+and(c4), but it isn't - // profitable unless we know the and'd out bits are already zero. 
- if (CI->getZExtValue() > NumBits) { - unsigned LowBits = TypeWidth - CI->getZExtValue(); - if (IC.MaskedValueIsZero(I->getOperand(0), - APInt::getLowBitsSet(TypeWidth, NumBits) << LowBits, - 0, CxtI)) - return true; - } - - return false; - } - case Instruction::LShr: { - // We can often fold the shift into shifts-by-a-constant. - CI = dyn_cast<ConstantInt>(I->getOperand(1)); - if (!CI) return false; - - // We can always fold lshr(c1)+lshr(c2) -> lshr(c1+c2). - if (!isLeftShift) return true; - - // We can always turn lshr(c)+shl(c) -> and(c2). - if (CI->getValue() == NumBits) return true; - - unsigned TypeWidth = I->getType()->getScalarSizeInBits(); - - // We can always turn lshr(c1)+shl(c2) -> lshr(c3)+and(c4), but it isn't - // profitable unless we know the and'd out bits are already zero. - if (CI->getValue().ult(TypeWidth) && CI->getZExtValue() > NumBits) { - unsigned LowBits = CI->getZExtValue() - NumBits; - if (IC.MaskedValueIsZero(I->getOperand(0), - APInt::getLowBitsSet(TypeWidth, NumBits) << LowBits, - 0, CxtI)) - return true; - } - - return false; - } case Instruction::Select: { SelectInst *SI = cast<SelectInst>(I); - return CanEvaluateShifted(SI->getTrueValue(), NumBits, isLeftShift, - IC, SI) && - CanEvaluateShifted(SI->getFalseValue(), NumBits, isLeftShift, IC, SI); + Value *TrueVal = SI->getTrueValue(); + Value *FalseVal = SI->getFalseValue(); + return CanEvaluateShifted(TrueVal, NumBits, IsLeftShift, IC, SI) && + CanEvaluateShifted(FalseVal, NumBits, IsLeftShift, IC, SI); } case Instruction::PHI: { // We can change a phi if we can change all operands. Note that we never @@ -176,8 +176,7 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift, // instructions with a single use. PHINode *PN = cast<PHINode>(I); for (Value *IncValue : PN->incoming_values()) - if (!CanEvaluateShifted(IncValue, NumBits, isLeftShift, - IC, PN)) + if (!CanEvaluateShifted(IncValue, NumBits, IsLeftShift, IC, PN)) return false; return true; } @@ -257,6 +256,8 @@ static Value *GetShiftedValue(Value *V, unsigned NumBits, bool isLeftShift, BO->setHasNoSignedWrap(false); return BO; } + // FIXME: This is almost identical to the SHL case. Refactor both cases into + // a helper function. case Instruction::LShr: { BinaryOperator *BO = cast<BinaryOperator>(I); unsigned TypeWidth = BO->getType()->getScalarSizeInBits(); @@ -340,7 +341,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, DEBUG(dbgs() << "ICE: GetShiftedValue propagating shift through expression" " to eliminate shift:\n IN: " << *Op0 << "\n SH: " << I <<"\n"); - return ReplaceInstUsesWith( + return replaceInstUsesWith( I, GetShiftedValue(Op0, COp1->getZExtValue(), isLeftShift, *this, DL)); } @@ -356,7 +357,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, if (BO->getOpcode() == Instruction::Mul && isLeftShift) if (Constant *BOOp = dyn_cast<Constant>(BO->getOperand(1))) return BinaryOperator::CreateMul(BO->getOperand(0), - ConstantExpr::getShl(BOOp, Op1)); + ConstantExpr::getShl(BOOp, Op1)); // Try to fold constant and into select arguments. if (SelectInst *SI = dyn_cast<SelectInst>(Op0)) @@ -573,7 +574,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, // saturates. if (AmtSum >= TypeBits) { if (I.getOpcode() != Instruction::AShr) - return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); + return replaceInstUsesWith(I, Constant::getNullValue(I.getType())); AmtSum = TypeBits-1; // Saturate to 31 for i32 ashr. 
} @@ -694,12 +695,12 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, Instruction *InstCombiner::visitShl(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyShlInst(I.getOperand(0), I.getOperand(1), I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Instruction *V = commonShiftTransforms(I)) return V; @@ -710,11 +711,11 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { // If the shifted-out value is known-zero, then this is a NUW shift. if (!I.hasNoUnsignedWrap() && MaskedValueIsZero(I.getOperand(0), - APInt::getHighBitsSet(Op1C->getBitWidth(), ShAmt), - 0, &I)) { - I.setHasNoUnsignedWrap(); - return &I; - } + APInt::getHighBitsSet(Op1C->getBitWidth(), ShAmt), 0, + &I)) { + I.setHasNoUnsignedWrap(); + return &I; + } // If the shifted out value is all signbits, this is a NSW shift. if (!I.hasNoSignedWrap() && @@ -736,11 +737,11 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { Instruction *InstCombiner::visitLShr(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyLShrInst(I.getOperand(0), I.getOperand(1), I.isExact(), DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Instruction *R = commonShiftTransforms(I)) return R; @@ -780,11 +781,11 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) { Instruction *InstCombiner::visitAShr(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyAShrInst(I.getOperand(0), I.getOperand(1), I.isExact(), DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Instruction *R = commonShiftTransforms(I)) return R; @@ -813,8 +814,8 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) { // If the shifted-out value is known-zero, then this is an exact shift. if (!I.isExact() && - MaskedValueIsZero(Op0,APInt::getLowBitsSet(Op1C->getBitWidth(),ShAmt), - 0, &I)){ + MaskedValueIsZero(Op0, APInt::getLowBitsSet(Op1C->getBitWidth(), ShAmt), + 0, &I)) { I.setIsExact(); return &I; } diff --git a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index 743d51483ea16..f3268d2c34714 100644 --- a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -22,10 +22,9 @@ using namespace llvm::PatternMatch; #define DEBUG_TYPE "instcombine" -/// ShrinkDemandedConstant - Check to see if the specified operand of the -/// specified instruction is a constant integer. If so, check to see if there -/// are any bits set in the constant that are not demanded. If so, shrink the -/// constant and return true. +/// Check to see if the specified operand of the specified instruction is a +/// constant integer. If so, check to see if there are any bits set in the +/// constant that are not demanded. If so, shrink the constant and return true. 
static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo, APInt Demanded) { assert(I && "No instruction?"); @@ -49,9 +48,8 @@ static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo, -/// SimplifyDemandedInstructionBits - Inst is an integer instruction that -/// SimplifyDemandedBits knows about. See if the instruction has any -/// properties that allow us to simplify its operands. +/// Inst is an integer instruction that SimplifyDemandedBits knows about. See if +/// the instruction has any properties that allow us to simplify its operands. bool InstCombiner::SimplifyDemandedInstructionBits(Instruction &Inst) { unsigned BitWidth = Inst.getType()->getScalarSizeInBits(); APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); @@ -61,14 +59,14 @@ bool InstCombiner::SimplifyDemandedInstructionBits(Instruction &Inst) { 0, &Inst); if (!V) return false; if (V == &Inst) return true; - ReplaceInstUsesWith(Inst, V); + replaceInstUsesWith(Inst, V); return true; } -/// SimplifyDemandedBits - This form of SimplifyDemandedBits simplifies the -/// specified instruction operand if possible, updating it in place. It returns -/// true if it made any change and false otherwise. -bool InstCombiner::SimplifyDemandedBits(Use &U, APInt DemandedMask, +/// This form of SimplifyDemandedBits simplifies the specified instruction +/// operand if possible, updating it in place. It returns true if it made any +/// change and false otherwise. +bool InstCombiner::SimplifyDemandedBits(Use &U, const APInt &DemandedMask, APInt &KnownZero, APInt &KnownOne, unsigned Depth) { auto *UserI = dyn_cast<Instruction>(U.getUser()); @@ -80,21 +78,22 @@ bool InstCombiner::SimplifyDemandedBits(Use &U, APInt DemandedMask, } -/// SimplifyDemandedUseBits - This function attempts to replace V with a simpler -/// value based on the demanded bits. When this function is called, it is known -/// that only the bits set in DemandedMask of the result of V are ever used -/// downstream. Consequently, depending on the mask and V, it may be possible -/// to replace V with a constant or one of its operands. In such cases, this -/// function does the replacement and returns true. In all other cases, it -/// returns false after analyzing the expression and setting KnownOne and known -/// to be one in the expression. KnownZero contains all the bits that are known -/// to be zero in the expression. These are provided to potentially allow the -/// caller (which might recursively be SimplifyDemandedBits itself) to simplify -/// the expression. KnownOne and KnownZero always follow the invariant that -/// KnownOne & KnownZero == 0. That is, a bit can't be both 1 and 0. Note that -/// the bits in KnownOne and KnownZero may only be accurate for those bits set -/// in DemandedMask. Note also that the bitwidth of V, DemandedMask, KnownZero -/// and KnownOne must all be the same. +/// This function attempts to replace V with a simpler value based on the +/// demanded bits. When this function is called, it is known that only the bits +/// set in DemandedMask of the result of V are ever used downstream. +/// Consequently, depending on the mask and V, it may be possible to replace V +/// with a constant or one of its operands. In such cases, this function does +/// the replacement and returns true. In all other cases, it returns false after +/// analyzing the expression and setting KnownOne and known to be one in the +/// expression. KnownZero contains all the bits that are known to be zero in the +/// expression. 
These are provided to potentially allow the caller (which might +/// recursively be SimplifyDemandedBits itself) to simplify the expression. +/// KnownOne and KnownZero always follow the invariant that: +/// KnownOne & KnownZero == 0. +/// That is, a bit can't be both 1 and 0. Note that the bits in KnownOne and +/// KnownZero may only be accurate for those bits set in DemandedMask. Note also +/// that the bitwidth of V, DemandedMask, KnownZero and KnownOne must all be the +/// same. /// /// This returns null if it did not change anything and it permits no /// simplification. This returns V itself if it did some simplification of V's @@ -768,6 +767,34 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // TODO: Could compute known zero/one bits based on the input. break; } + case Intrinsic::x86_mmx_pmovmskb: + case Intrinsic::x86_sse_movmsk_ps: + case Intrinsic::x86_sse2_movmsk_pd: + case Intrinsic::x86_sse2_pmovmskb_128: + case Intrinsic::x86_avx_movmsk_ps_256: + case Intrinsic::x86_avx_movmsk_pd_256: + case Intrinsic::x86_avx2_pmovmskb: { + // MOVMSK copies the vector elements' sign bits to the low bits + // and zeros the high bits. + unsigned ArgWidth; + if (II->getIntrinsicID() == Intrinsic::x86_mmx_pmovmskb) { + ArgWidth = 8; // Arg is x86_mmx, but treated as <8 x i8>. + } else { + auto Arg = II->getArgOperand(0); + auto ArgType = cast<VectorType>(Arg->getType()); + ArgWidth = ArgType->getNumElements(); + } + + // If we don't need any of low bits then return zero, + // we know that DemandedMask is non-zero already. + APInt DemandedElts = DemandedMask.zextOrTrunc(ArgWidth); + if (DemandedElts == 0) + return ConstantInt::getNullValue(VTy); + + // We know that the upper bits are set to zero. + KnownZero = APInt::getHighBitsSet(BitWidth, BitWidth - ArgWidth); + return nullptr; + } case Intrinsic::x86_sse42_crc32_64_64: KnownZero = APInt::getHighBitsSet(64, 32); return nullptr; @@ -802,7 +829,10 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, /// As with SimplifyDemandedUseBits, it returns NULL if the simplification was /// not successful. Value *InstCombiner::SimplifyShrShlDemandedBits(Instruction *Shr, - Instruction *Shl, APInt DemandedMask, APInt &KnownZero, APInt &KnownOne) { + Instruction *Shl, + const APInt &DemandedMask, + APInt &KnownZero, + APInt &KnownOne) { const APInt &ShlOp1 = cast<ConstantInt>(Shl->getOperand(1))->getValue(); const APInt &ShrOp1 = cast<ConstantInt>(Shr->getOperand(1))->getValue(); @@ -865,10 +895,10 @@ Value *InstCombiner::SimplifyShrShlDemandedBits(Instruction *Shr, return nullptr; } -/// SimplifyDemandedVectorElts - The specified value produces a vector with -/// any number of elements. DemandedElts contains the set of elements that are -/// actually used by the caller. This method analyzes which elements of the -/// operand are undef and returns that information in UndefElts. +/// The specified value produces a vector with any number of elements. +/// DemandedElts contains the set of elements that are actually used by the +/// caller. This method analyzes which elements of the operand are undef and +/// returns that information in UndefElts. 
/// /// If the information about demanded elements can be used to simplify the /// operation, the operation is simplified, then the resultant value is @@ -876,7 +906,7 @@ Value *InstCombiner::SimplifyShrShlDemandedBits(Instruction *Shr, Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, APInt &UndefElts, unsigned Depth) { - unsigned VWidth = cast<VectorType>(V->getType())->getNumElements(); + unsigned VWidth = V->getType()->getVectorNumElements(); APInt EltMask(APInt::getAllOnesValue(VWidth)); assert((DemandedElts & ~EltMask) == 0 && "Invalid DemandedElts!"); @@ -1179,16 +1209,42 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, switch (II->getIntrinsicID()) { default: break; - // Binary vector operations that work column-wise. A dest element is a - // function of the corresponding input elements from the two inputs. + // Unary scalar-as-vector operations that work column-wise. + case Intrinsic::x86_sse_rcp_ss: + case Intrinsic::x86_sse_rsqrt_ss: + case Intrinsic::x86_sse_sqrt_ss: + case Intrinsic::x86_sse2_sqrt_sd: + case Intrinsic::x86_xop_vfrcz_ss: + case Intrinsic::x86_xop_vfrcz_sd: + TmpV = SimplifyDemandedVectorElts(II->getArgOperand(0), DemandedElts, + UndefElts, Depth + 1); + if (TmpV) { II->setArgOperand(0, TmpV); MadeChange = true; } + + // If lowest element of a scalar op isn't used then use Arg0. + if (DemandedElts.getLoBits(1) != 1) + return II->getArgOperand(0); + // TODO: If only low elt lower SQRT to FSQRT (with rounding/exceptions + // checks). + break; + + // Binary scalar-as-vector operations that work column-wise. A dest element + // is a function of the corresponding input elements from the two inputs. + case Intrinsic::x86_sse_add_ss: case Intrinsic::x86_sse_sub_ss: case Intrinsic::x86_sse_mul_ss: + case Intrinsic::x86_sse_div_ss: case Intrinsic::x86_sse_min_ss: case Intrinsic::x86_sse_max_ss: + case Intrinsic::x86_sse_cmp_ss: + case Intrinsic::x86_sse2_add_sd: case Intrinsic::x86_sse2_sub_sd: case Intrinsic::x86_sse2_mul_sd: + case Intrinsic::x86_sse2_div_sd: case Intrinsic::x86_sse2_min_sd: case Intrinsic::x86_sse2_max_sd: + case Intrinsic::x86_sse2_cmp_sd: + case Intrinsic::x86_sse41_round_ss: + case Intrinsic::x86_sse41_round_sd: TmpV = SimplifyDemandedVectorElts(II->getArgOperand(0), DemandedElts, UndefElts, Depth + 1); if (TmpV) { II->setArgOperand(0, TmpV); MadeChange = true; } @@ -1201,11 +1257,15 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, if (DemandedElts == 1) { switch (II->getIntrinsicID()) { default: break; + case Intrinsic::x86_sse_add_ss: case Intrinsic::x86_sse_sub_ss: case Intrinsic::x86_sse_mul_ss: + case Intrinsic::x86_sse_div_ss: + case Intrinsic::x86_sse2_add_sd: case Intrinsic::x86_sse2_sub_sd: case Intrinsic::x86_sse2_mul_sd: - // TODO: Lower MIN/MAX/ABS/etc + case Intrinsic::x86_sse2_div_sd: + // TODO: Lower MIN/MAX/etc. Value *LHS = II->getArgOperand(0); Value *RHS = II->getArgOperand(1); // Extract the element as scalars. 
@@ -1216,6 +1276,11 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, switch (II->getIntrinsicID()) { default: llvm_unreachable("Case stmts out of sync!"); + case Intrinsic::x86_sse_add_ss: + case Intrinsic::x86_sse2_add_sd: + TmpV = InsertNewInstWith(BinaryOperator::CreateFAdd(LHS, RHS, + II->getName()), *II); + break; case Intrinsic::x86_sse_sub_ss: case Intrinsic::x86_sse2_sub_sd: TmpV = InsertNewInstWith(BinaryOperator::CreateFSub(LHS, RHS, @@ -1226,6 +1291,11 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, TmpV = InsertNewInstWith(BinaryOperator::CreateFMul(LHS, RHS, II->getName()), *II); break; + case Intrinsic::x86_sse_div_ss: + case Intrinsic::x86_sse2_div_sd: + TmpV = InsertNewInstWith(BinaryOperator::CreateFDiv(LHS, RHS, + II->getName()), *II); + break; } Instruction *New = @@ -1238,6 +1308,10 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, } } + // If lowest element of a scalar op isn't used then use Arg0. + if (DemandedElts.getLoBits(1) != 1) + return II->getArgOperand(0); + // Output elements are undefined if both are undefined. Consider things // like undef&0. The result is known zero, not undef. UndefElts &= UndefElts2; diff --git a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index bc4c0ebae7903..a761387561480 100644 --- a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -62,21 +62,31 @@ static bool cheapToScalarize(Value *V, bool isConstant) { return false; } -// If we have a PHI node with a vector type that has only 2 uses: feed +// If we have a PHI node with a vector type that is only used to feed // itself and be an operand of extractelement at a constant location, // try to replace the PHI of the vector type with a PHI of a scalar type. Instruction *InstCombiner::scalarizePHI(ExtractElementInst &EI, PHINode *PN) { - // Verify that the PHI node has exactly 2 uses. Otherwise return NULL. - if (!PN->hasNUses(2)) - return nullptr; + SmallVector<Instruction *, 2> Extracts; + // The users we want the PHI to have are: + // 1) The EI ExtractElement (we already know this) + // 2) Possibly more ExtractElements with the same index. + // 3) Another operand, which will feed back into the PHI. + Instruction *PHIUser = nullptr; + for (auto U : PN->users()) { + if (ExtractElementInst *EU = dyn_cast<ExtractElementInst>(U)) { + if (EI.getIndexOperand() == EU->getIndexOperand()) + Extracts.push_back(EU); + else + return nullptr; + } else if (!PHIUser) { + PHIUser = cast<Instruction>(U); + } else { + return nullptr; + } + } - // If so, it's known at this point that one operand is PHI and the other is - // an extractelement node. Find the PHI user that is not the extractelement - // node. - auto iu = PN->user_begin(); - Instruction *PHIUser = dyn_cast<Instruction>(*iu); - if (PHIUser == cast<Instruction>(&EI)) - PHIUser = cast<Instruction>(*(++iu)); + if (!PHIUser) + return nullptr; // Verify that this PHI user has one use, which is the PHI itself, // and that it is a binary operation which is cheap to scalarize. 
@@ -106,7 +116,8 @@ Instruction *InstCombiner::scalarizePHI(ExtractElementInst &EI, PHINode *PN) { B0->getOperand(opId)->getName() + ".Elt"), *B0); Value *newPHIUser = InsertNewInstWith( - BinaryOperator::Create(B0->getOpcode(), scalarPHI, Op), *B0); + BinaryOperator::CreateWithCopiedFlags(B0->getOpcode(), + scalarPHI, Op, B0), *B0); scalarPHI->addIncoming(newPHIUser, inBB); } else { // Scalarize PHI input: @@ -125,19 +136,23 @@ Instruction *InstCombiner::scalarizePHI(ExtractElementInst &EI, PHINode *PN) { scalarPHI->addIncoming(newEI, inBB); } } - return ReplaceInstUsesWith(EI, scalarPHI); + + for (auto E : Extracts) + replaceInstUsesWith(*E, scalarPHI); + + return &EI; } Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) { if (Value *V = SimplifyExtractElementInst( EI.getVectorOperand(), EI.getIndexOperand(), DL, TLI, DT, AC)) - return ReplaceInstUsesWith(EI, V); + return replaceInstUsesWith(EI, V); // If vector val is constant with all elements the same, replace EI with // that element. We handle a known element # below. if (Constant *C = dyn_cast<Constant>(EI.getOperand(0))) if (cheapToScalarize(C, false)) - return ReplaceInstUsesWith(EI, C->getAggregateElement(0U)); + return replaceInstUsesWith(EI, C->getAggregateElement(0U)); // If extracting a specified index from the vector, see if we can recursively // find a previously computed scalar that was inserted into the vector. @@ -193,12 +208,13 @@ Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) { Value *newEI1 = Builder->CreateExtractElement(BO->getOperand(1), EI.getOperand(1), EI.getName()+".rhs"); - return BinaryOperator::Create(BO->getOpcode(), newEI0, newEI1); + return BinaryOperator::CreateWithCopiedFlags(BO->getOpcode(), + newEI0, newEI1, BO); } } else if (InsertElementInst *IE = dyn_cast<InsertElementInst>(I)) { // Extracting the inserted element? if (IE->getOperand(2) == EI.getOperand(1)) - return ReplaceInstUsesWith(EI, IE->getOperand(1)); + return replaceInstUsesWith(EI, IE->getOperand(1)); // If the inserted and extracted elements are constants, they must not // be the same value, extract from the pre-inserted value instead. if (isa<Constant>(IE->getOperand(2)) && isa<Constant>(EI.getOperand(1))) { @@ -216,7 +232,7 @@ Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) { SVI->getOperand(0)->getType()->getVectorNumElements(); if (SrcIdx < 0) - return ReplaceInstUsesWith(EI, UndefValue::get(EI.getType())); + return replaceInstUsesWith(EI, UndefValue::get(EI.getType())); if (SrcIdx < (int)LHSWidth) Src = SVI->getOperand(0); else { @@ -417,7 +433,7 @@ static void replaceExtractElements(InsertElementInst *InsElt, continue; auto *NewExt = ExtractElementInst::Create(WideVec, OldExt->getOperand(1)); NewExt->insertAfter(WideVec); - IC.ReplaceInstUsesWith(*OldExt, NewExt); + IC.replaceInstUsesWith(*OldExt, NewExt); } } @@ -546,7 +562,7 @@ Instruction *InstCombiner::visitInsertValueInst(InsertValueInst &I) { } if (IsRedundant) - return ReplaceInstUsesWith(I, I.getOperand(0)); + return replaceInstUsesWith(I, I.getOperand(0)); return nullptr; } @@ -557,7 +573,7 @@ Instruction *InstCombiner::visitInsertElementInst(InsertElementInst &IE) { // Inserting an undef or into an undefined place, remove this. 
if (isa<UndefValue>(ScalarOp) || isa<UndefValue>(IdxOp)) - ReplaceInstUsesWith(IE, VecOp); + replaceInstUsesWith(IE, VecOp); // If the inserted element was extracted from some other vector, and if the // indexes are constant, try to turn this into a shufflevector operation. @@ -571,15 +587,15 @@ Instruction *InstCombiner::visitInsertElementInst(InsertElementInst &IE) { unsigned InsertedIdx = cast<ConstantInt>(IdxOp)->getZExtValue(); if (ExtractedIdx >= NumExtractVectorElts) // Out of range extract. - return ReplaceInstUsesWith(IE, VecOp); + return replaceInstUsesWith(IE, VecOp); if (InsertedIdx >= NumInsertVectorElts) // Out of range insert. - return ReplaceInstUsesWith(IE, UndefValue::get(IE.getType())); + return replaceInstUsesWith(IE, UndefValue::get(IE.getType())); // If we are extracting a value from a vector, then inserting it right // back into the same place, just use the input vector. if (EI->getOperand(0) == VecOp && ExtractedIdx == InsertedIdx) - return ReplaceInstUsesWith(IE, VecOp); + return replaceInstUsesWith(IE, VecOp); // If this insertelement isn't used by some other insertelement, turn it // (and any insertelements it points to), into one big shuffle. @@ -605,7 +621,7 @@ Instruction *InstCombiner::visitInsertElementInst(InsertElementInst &IE) { APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth)); if (Value *V = SimplifyDemandedVectorElts(&IE, AllOnesEltMask, UndefElts)) { if (V != &IE) - return ReplaceInstUsesWith(IE, V); + return replaceInstUsesWith(IE, V); return &IE; } @@ -910,7 +926,7 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { // Undefined shuffle mask -> undefined value. if (isa<UndefValue>(SVI.getOperand(2))) - return ReplaceInstUsesWith(SVI, UndefValue::get(SVI.getType())); + return replaceInstUsesWith(SVI, UndefValue::get(SVI.getType())); unsigned VWidth = cast<VectorType>(SVI.getType())->getNumElements(); @@ -918,7 +934,7 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth)); if (Value *V = SimplifyDemandedVectorElts(&SVI, AllOnesEltMask, UndefElts)) { if (V != &SVI) - return ReplaceInstUsesWith(SVI, V); + return replaceInstUsesWith(SVI, V); LHS = SVI.getOperand(0); RHS = SVI.getOperand(1); MadeChange = true; @@ -933,7 +949,7 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { // shuffle(undef,undef,mask) -> undef. Value *Result = (VWidth == LHSWidth) ? LHS : UndefValue::get(SVI.getType()); - return ReplaceInstUsesWith(SVI, Result); + return replaceInstUsesWith(SVI, Result); } // Remap any references to RHS to use LHS. @@ -967,13 +983,13 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { recognizeIdentityMask(Mask, isLHSID, isRHSID); // Eliminate identity shuffles. - if (isLHSID) return ReplaceInstUsesWith(SVI, LHS); - if (isRHSID) return ReplaceInstUsesWith(SVI, RHS); + if (isLHSID) return replaceInstUsesWith(SVI, LHS); + if (isRHSID) return replaceInstUsesWith(SVI, RHS); } if (isa<UndefValue>(RHS) && CanEvaluateShuffled(LHS, Mask)) { Value *V = EvaluateInDifferentElementOrder(LHS, Mask); - return ReplaceInstUsesWith(SVI, V); + return replaceInstUsesWith(SVI, V); } // SROA generates shuffle+bitcast when the extracted sub-vector is bitcast to @@ -1060,7 +1076,7 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { NewBC, ConstantInt::get(Int32Ty, BegIdx), SVI.getName() + ".extract"); // The shufflevector isn't being replaced: the bitcast that used it // is. 
InstCombine will visit the newly-created instructions. - ReplaceInstUsesWith(*BC, Ext); + replaceInstUsesWith(*BC, Ext); MadeChange = true; } } @@ -1251,8 +1267,8 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { // corresponding argument. bool isLHSID, isRHSID; recognizeIdentityMask(newMask, isLHSID, isRHSID); - if (isLHSID && VWidth == LHSOp0Width) return ReplaceInstUsesWith(SVI, newLHS); - if (isRHSID && VWidth == RHSOp0Width) return ReplaceInstUsesWith(SVI, newRHS); + if (isLHSID && VWidth == LHSOp0Width) return replaceInstUsesWith(SVI, newLHS); + if (isRHSID && VWidth == RHSOp0Width) return replaceInstUsesWith(SVI, newRHS); return MadeChange ? &SVI : nullptr; } diff --git a/lib/Transforms/InstCombine/InstructionCombining.cpp b/lib/Transforms/InstCombine/InstructionCombining.cpp index 903a0b5f5400a..51c3262b5d14f 100644 --- a/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -39,7 +39,9 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringSwitch.h" +#include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/EHPersonalities.h" @@ -76,6 +78,10 @@ STATISTIC(NumExpand, "Number of expansions"); STATISTIC(NumFactor , "Number of factorizations"); STATISTIC(NumReassoc , "Number of reassociations"); +static cl::opt<bool> +EnableExpensiveCombines("expensive-combines", + cl::desc("Enable expensive instruction combines")); + Value *InstCombiner::EmitGEPOffset(User *GEP) { return llvm::EmitGEPOffset(Builder, DL, GEP); } @@ -120,33 +126,23 @@ bool InstCombiner::ShouldChangeType(Type *From, Type *To) const { // all other opcodes, the function conservatively returns false. static bool MaintainNoSignedWrap(BinaryOperator &I, Value *B, Value *C) { OverflowingBinaryOperator *OBO = dyn_cast<OverflowingBinaryOperator>(&I); - if (!OBO || !OBO->hasNoSignedWrap()) { + if (!OBO || !OBO->hasNoSignedWrap()) return false; - } // We reason about Add and Sub Only. 
Instruction::BinaryOps Opcode = I.getOpcode(); - if (Opcode != Instruction::Add && - Opcode != Instruction::Sub) { + if (Opcode != Instruction::Add && Opcode != Instruction::Sub) return false; - } - - ConstantInt *CB = dyn_cast<ConstantInt>(B); - ConstantInt *CC = dyn_cast<ConstantInt>(C); - if (!CB || !CC) { + const APInt *BVal, *CVal; + if (!match(B, m_APInt(BVal)) || !match(C, m_APInt(CVal))) return false; - } - const APInt &BVal = CB->getValue(); - const APInt &CVal = CC->getValue(); bool Overflow = false; - - if (Opcode == Instruction::Add) { - BVal.sadd_ov(CVal, Overflow); - } else { - BVal.ssub_ov(CVal, Overflow); - } + if (Opcode == Instruction::Add) + BVal->sadd_ov(*CVal, Overflow); + else + BVal->ssub_ov(*CVal, Overflow); return !Overflow; } @@ -166,6 +162,49 @@ static void ClearSubclassDataAfterReassociation(BinaryOperator &I) { I.setFastMathFlags(FMF); } +/// Combine constant operands of associative operations either before or after a +/// cast to eliminate one of the associative operations: +/// (op (cast (op X, C2)), C1) --> (cast (op X, op (C1, C2))) +/// (op (cast (op X, C2)), C1) --> (op (cast X), op (C1, C2)) +static bool simplifyAssocCastAssoc(BinaryOperator *BinOp1) { + auto *Cast = dyn_cast<CastInst>(BinOp1->getOperand(0)); + if (!Cast || !Cast->hasOneUse()) + return false; + + // TODO: Enhance logic for other casts and remove this check. + auto CastOpcode = Cast->getOpcode(); + if (CastOpcode != Instruction::ZExt) + return false; + + // TODO: Enhance logic for other BinOps and remove this check. + auto AssocOpcode = BinOp1->getOpcode(); + if (AssocOpcode != Instruction::Xor && AssocOpcode != Instruction::And && + AssocOpcode != Instruction::Or) + return false; + + auto *BinOp2 = dyn_cast<BinaryOperator>(Cast->getOperand(0)); + if (!BinOp2 || !BinOp2->hasOneUse() || BinOp2->getOpcode() != AssocOpcode) + return false; + + Constant *C1, *C2; + if (!match(BinOp1->getOperand(1), m_Constant(C1)) || + !match(BinOp2->getOperand(1), m_Constant(C2))) + return false; + + // TODO: This assumes a zext cast. + // Eg, if it was a trunc, we'd cast C1 to the source type because casting C2 + // to the destination type might lose bits. + + // Fold the constants together in the destination type: + // (op (cast (op X, C2)), C1) --> (op (cast X), FoldedC) + Type *DestTy = C1->getType(); + Constant *CastC2 = ConstantExpr::getCast(CastOpcode, C2, DestTy); + Constant *FoldedC = ConstantExpr::get(AssocOpcode, C1, CastC2); + Cast->setOperand(0, BinOp2->getOperand(0)); + BinOp1->setOperand(1, FoldedC); + return true; +} + /// This performs a few simplifications for operators that are associative or /// commutative: /// @@ -253,6 +292,12 @@ bool InstCombiner::SimplifyAssociativeOrCommutative(BinaryOperator &I) { } if (I.isAssociative() && I.isCommutative()) { + if (simplifyAssocCastAssoc(&I)) { + Changed = true; + ++NumReassoc; + continue; + } + // Transform: "(A op B) op C" ==> "(C op A) op B" if "C op A" simplifies. 
if (Op0 && Op0->getOpcode() == Opcode) { Value *A = Op0->getOperand(0); @@ -919,10 +964,10 @@ Instruction *InstCombiner::FoldOpIntoPhi(Instruction &I) { for (auto UI = PN->user_begin(), E = PN->user_end(); UI != E;) { Instruction *User = cast<Instruction>(*UI++); if (User == &I) continue; - ReplaceInstUsesWith(*User, NewPN); - EraseInstFromFunction(*User); + replaceInstUsesWith(*User, NewPN); + eraseInstFromFunction(*User); } - return ReplaceInstUsesWith(I, NewPN); + return replaceInstUsesWith(I, NewPN); } /// Given a pointer type and a constant offset, determine whether or not there @@ -1334,8 +1379,8 @@ Value *InstCombiner::SimplifyVectorOp(BinaryOperator &Inst) { Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { SmallVector<Value*, 8> Ops(GEP.op_begin(), GEP.op_end()); - if (Value *V = SimplifyGEPInst(Ops, DL, TLI, DT, AC)) - return ReplaceInstUsesWith(GEP, V); + if (Value *V = SimplifyGEPInst(GEP.getSourceElementType(), Ops, DL, TLI, DT, AC)) + return replaceInstUsesWith(GEP, V); Value *PtrOp = GEP.getOperand(0); @@ -1349,19 +1394,18 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { for (User::op_iterator I = GEP.op_begin() + 1, E = GEP.op_end(); I != E; ++I, ++GTI) { // Skip indices into struct types. - SequentialType *SeqTy = dyn_cast<SequentialType>(*GTI); - if (!SeqTy) + if (isa<StructType>(*GTI)) continue; // Index type should have the same width as IntPtr Type *IndexTy = (*I)->getType(); Type *NewIndexType = IndexTy->isVectorTy() ? VectorType::get(IntPtrTy, IndexTy->getVectorNumElements()) : IntPtrTy; - + // If the element type has zero size then any index over it is equivalent // to an index of zero, so replace it with zero if it is not zero already. - if (SeqTy->getElementType()->isSized() && - DL.getTypeAllocSize(SeqTy->getElementType()) == 0) + Type *EltTy = GTI.getIndexedType(); + if (EltTy->isSized() && DL.getTypeAllocSize(EltTy) == 0) if (!isa<Constant>(*I) || !cast<Constant>(*I)->isNullValue()) { *I = Constant::getNullValue(NewIndexType); MadeChange = true; @@ -1393,7 +1437,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (Op1 == &GEP) return nullptr; - signed DI = -1; + int DI = -1; for (auto I = PN->op_begin()+1, E = PN->op_end(); I !=E; ++I) { GetElementPtrInst *Op2 = dyn_cast<GetElementPtrInst>(*I); @@ -1405,7 +1449,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { return nullptr; // Keep track of the type as we walk the GEP. - Type *CurTy = Op1->getOperand(0)->getType()->getScalarType(); + Type *CurTy = nullptr; for (unsigned J = 0, F = Op1->getNumOperands(); J != F; ++J) { if (Op1->getOperand(J)->getType() != Op2->getOperand(J)->getType()) @@ -1436,7 +1480,9 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // Sink down a layer of the type for the next iteration. 
if (J > 0) { - if (CompositeType *CT = dyn_cast<CompositeType>(CurTy)) { + if (J == 1) { + CurTy = Op1->getSourceElementType(); + } else if (CompositeType *CT = dyn_cast<CompositeType>(CurTy)) { CurTy = CT->getTypeAtIndex(Op1->getOperand(J)); } else { CurTy = nullptr; @@ -1565,8 +1611,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { unsigned AS = GEP.getPointerAddressSpace(); if (GEP.getOperand(1)->getType()->getScalarSizeInBits() == DL.getPointerSizeInBits(AS)) { - Type *PtrTy = GEP.getPointerOperandType(); - Type *Ty = PtrTy->getPointerElementType(); + Type *Ty = GEP.getSourceElementType(); uint64_t TyAllocSize = DL.getTypeAllocSize(Ty); bool Matched = false; @@ -1629,9 +1674,8 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // // This occurs when the program declares an array extern like "int X[];" if (HasZeroPointerIndex) { - PointerType *CPTy = cast<PointerType>(PtrOp->getType()); if (ArrayType *CATy = - dyn_cast<ArrayType>(CPTy->getElementType())) { + dyn_cast<ArrayType>(GEP.getSourceElementType())) { // GEP (bitcast i8* X to [0 x i8]*), i32 0, ... ? if (CATy->getElementType() == StrippedPtrTy->getElementType()) { // -> GEP i8* X, ... @@ -1688,7 +1732,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // %t = getelementptr i32* bitcast ([2 x i32]* %str to i32*), i32 %V // into: %t1 = getelementptr [2 x i32]* %str, i32 0, i32 %V; bitcast Type *SrcElTy = StrippedPtrTy->getElementType(); - Type *ResElTy = PtrOp->getType()->getPointerElementType(); + Type *ResElTy = GEP.getSourceElementType(); if (SrcElTy->isArrayTy() && DL.getTypeAllocSize(SrcElTy->getArrayElementType()) == DL.getTypeAllocSize(ResElTy)) { @@ -1822,7 +1866,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (I != BCI) { I->takeName(BCI); BCI->getParent()->getInstList().insert(BCI->getIterator(), I); - ReplaceInstUsesWith(*BCI, I); + replaceInstUsesWith(*BCI, I); } return &GEP; } @@ -1844,7 +1888,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { : Builder->CreateGEP(nullptr, Operand, NewIndices); if (NGEP->getType() == GEP.getType()) - return ReplaceInstUsesWith(GEP, NGEP); + return replaceInstUsesWith(GEP, NGEP); NGEP->takeName(&GEP); if (NGEP->getType()->getPointerAddressSpace() != GEP.getAddressSpace()) @@ -1857,6 +1901,20 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { return nullptr; } +static bool isNeverEqualToUnescapedAlloc(Value *V, const TargetLibraryInfo *TLI, + Instruction *AI) { + if (isa<ConstantPointerNull>(V)) + return true; + if (auto *LI = dyn_cast<LoadInst>(V)) + return isa<GlobalVariable>(LI->getPointerOperand()); + // Two distinct allocations will never be equal. + // We rely on LookThroughBitCast in isAllocLikeFn being false, since looking + // through bitcasts of V can cause + // the result statement below to be true, even when AI and V (ex: + // i8* ->i32* ->i8* of AI) are the same allocations. + return isAllocLikeFn(V, TLI) && V != AI; +} + static bool isAllocSiteRemovable(Instruction *AI, SmallVectorImpl<WeakVH> &Users, const TargetLibraryInfo *TLI) { @@ -1881,7 +1939,12 @@ isAllocSiteRemovable(Instruction *AI, SmallVectorImpl<WeakVH> &Users, case Instruction::ICmp: { ICmpInst *ICI = cast<ICmpInst>(I); // We can fold eq/ne comparisons with null to false/true, respectively. 
- if (!ICI->isEquality() || !isa<ConstantPointerNull>(ICI->getOperand(1))) + // We also fold comparisons in some conditions provided the alloc has + // not escaped (see isNeverEqualToUnescapedAlloc). + if (!ICI->isEquality()) + return false; + unsigned OtherIndex = (ICI->getOperand(0) == PI) ? 1 : 0; + if (!isNeverEqualToUnescapedAlloc(ICI->getOperand(OtherIndex), TLI, AI)) return false; Users.emplace_back(I); continue; @@ -1941,23 +2004,40 @@ Instruction *InstCombiner::visitAllocSite(Instruction &MI) { SmallVector<WeakVH, 64> Users; if (isAllocSiteRemovable(&MI, Users, TLI)) { for (unsigned i = 0, e = Users.size(); i != e; ++i) { - Instruction *I = cast_or_null<Instruction>(&*Users[i]); - if (!I) continue; + // Lowering all @llvm.objectsize calls first because they may + // use a bitcast/GEP of the alloca we are removing. + if (!Users[i]) + continue; + + Instruction *I = cast<Instruction>(&*Users[i]); + + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { + if (II->getIntrinsicID() == Intrinsic::objectsize) { + uint64_t Size; + if (!getObjectSize(II->getArgOperand(0), Size, DL, TLI)) { + ConstantInt *CI = cast<ConstantInt>(II->getArgOperand(1)); + Size = CI->isZero() ? -1ULL : 0; + } + replaceInstUsesWith(*I, ConstantInt::get(I->getType(), Size)); + eraseInstFromFunction(*I); + Users[i] = nullptr; // Skip examining in the next loop. + } + } + } + for (unsigned i = 0, e = Users.size(); i != e; ++i) { + if (!Users[i]) + continue; + + Instruction *I = cast<Instruction>(&*Users[i]); if (ICmpInst *C = dyn_cast<ICmpInst>(I)) { - ReplaceInstUsesWith(*C, + replaceInstUsesWith(*C, ConstantInt::get(Type::getInt1Ty(C->getContext()), C->isFalseWhenEqual())); } else if (isa<BitCastInst>(I) || isa<GetElementPtrInst>(I)) { - ReplaceInstUsesWith(*I, UndefValue::get(I->getType())); - } else if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { - if (II->getIntrinsicID() == Intrinsic::objectsize) { - ConstantInt *CI = cast<ConstantInt>(II->getArgOperand(1)); - uint64_t DontKnow = CI->isZero() ? -1ULL : 0; - ReplaceInstUsesWith(*I, ConstantInt::get(I->getType(), DontKnow)); - } + replaceInstUsesWith(*I, UndefValue::get(I->getType())); } - EraseInstFromFunction(*I); + eraseInstFromFunction(*I); } if (InvokeInst *II = dyn_cast<InvokeInst>(&MI)) { @@ -1967,7 +2047,7 @@ Instruction *InstCombiner::visitAllocSite(Instruction &MI) { InvokeInst::Create(F, II->getNormalDest(), II->getUnwindDest(), None, "", II->getParent()); } - return EraseInstFromFunction(MI); + return eraseInstFromFunction(MI); } return nullptr; } @@ -2038,13 +2118,13 @@ Instruction *InstCombiner::visitFree(CallInst &FI) { // Insert a new store to null because we cannot modify the CFG here. Builder->CreateStore(ConstantInt::getTrue(FI.getContext()), UndefValue::get(Type::getInt1PtrTy(FI.getContext()))); - return EraseInstFromFunction(FI); + return eraseInstFromFunction(FI); } // If we have 'free null' delete the instruction. This can happen in stl code // when lots of inlining happens. if (isa<ConstantPointerNull>(Op)) - return EraseInstFromFunction(FI); + return eraseInstFromFunction(FI); // If we optimize for code size, try to move the call to free before the null // test so that simplify cfg can remove the empty block and dead code @@ -2145,6 +2225,7 @@ Instruction *InstCombiner::visitSwitchInst(SwitchInst &SI) { unsigned LeadingKnownOnes = KnownOne.countLeadingOnes(); // Compute the number of leading bits we can ignore. + // TODO: A better way to determine this would use ComputeNumSignBits(). 
for (auto &C : SI.cases()) { LeadingKnownZeros = std::min( LeadingKnownZeros, C.getCaseValue()->getValue().countLeadingZeros()); @@ -2154,17 +2235,15 @@ Instruction *InstCombiner::visitSwitchInst(SwitchInst &SI) { unsigned NewWidth = BitWidth - std::max(LeadingKnownZeros, LeadingKnownOnes); - // Truncate the condition operand if the new type is equal to or larger than - // the largest legal integer type. We need to be conservative here since - // x86 generates redundant zero-extension instructions if the operand is - // truncated to i8 or i16. + // Shrink the condition operand if the new type is smaller than the old type. + // This may produce a non-standard type for the switch, but that's ok because + // the backend should extend back to a legal type for the target. bool TruncCond = false; - if (NewWidth > 0 && BitWidth > NewWidth && - NewWidth >= DL.getLargestLegalIntTypeSize()) { + if (NewWidth > 0 && NewWidth < BitWidth) { TruncCond = true; IntegerType *Ty = IntegerType::get(SI.getContext(), NewWidth); Builder->SetInsertPoint(&SI); - Value *NewCond = Builder->CreateTrunc(SI.getCondition(), Ty, "trunc"); + Value *NewCond = Builder->CreateTrunc(Cond, Ty, "trunc"); SI.setCondition(NewCond); for (auto &C : SI.cases()) @@ -2172,28 +2251,27 @@ Instruction *InstCombiner::visitSwitchInst(SwitchInst &SI) { SI.getContext(), C.getCaseValue()->getValue().trunc(NewWidth))); } - if (Instruction *I = dyn_cast<Instruction>(Cond)) { - if (I->getOpcode() == Instruction::Add) - if (ConstantInt *AddRHS = dyn_cast<ConstantInt>(I->getOperand(1))) { - // change 'switch (X+4) case 1:' into 'switch (X) case -3' - // Skip the first item since that's the default case. - for (SwitchInst::CaseIt i = SI.case_begin(), e = SI.case_end(); - i != e; ++i) { - ConstantInt* CaseVal = i.getCaseValue(); - Constant *LHS = CaseVal; - if (TruncCond) - LHS = LeadingKnownZeros - ? ConstantExpr::getZExt(CaseVal, Cond->getType()) - : ConstantExpr::getSExt(CaseVal, Cond->getType()); - Constant* NewCaseVal = ConstantExpr::getSub(LHS, AddRHS); - assert(isa<ConstantInt>(NewCaseVal) && - "Result of expression should be constant"); - i.setValue(cast<ConstantInt>(NewCaseVal)); - } - SI.setCondition(I->getOperand(0)); - Worklist.Add(I); - return &SI; + ConstantInt *AddRHS = nullptr; + if (match(Cond, m_Add(m_Value(), m_ConstantInt(AddRHS)))) { + Instruction *I = cast<Instruction>(Cond); + // Change 'switch (X+4) case 1:' into 'switch (X) case -3'. + for (SwitchInst::CaseIt i = SI.case_begin(), e = SI.case_end(); i != e; + ++i) { + ConstantInt *CaseVal = i.getCaseValue(); + Constant *LHS = CaseVal; + if (TruncCond) { + LHS = LeadingKnownZeros + ? ConstantExpr::getZExt(CaseVal, Cond->getType()) + : ConstantExpr::getSExt(CaseVal, Cond->getType()); } + Constant *NewCaseVal = ConstantExpr::getSub(LHS, AddRHS); + assert(isa<ConstantInt>(NewCaseVal) && + "Result of expression should be constant"); + i.setValue(cast<ConstantInt>(NewCaseVal)); + } + SI.setCondition(I->getOperand(0)); + Worklist.Add(I); + return &SI; } return TruncCond ? 
&SI : nullptr; @@ -2203,11 +2281,11 @@ Instruction *InstCombiner::visitExtractValueInst(ExtractValueInst &EV) { Value *Agg = EV.getAggregateOperand(); if (!EV.hasIndices()) - return ReplaceInstUsesWith(EV, Agg); + return replaceInstUsesWith(EV, Agg); if (Value *V = SimplifyExtractValueInst(Agg, EV.getIndices(), DL, TLI, DT, AC)) - return ReplaceInstUsesWith(EV, V); + return replaceInstUsesWith(EV, V); if (InsertValueInst *IV = dyn_cast<InsertValueInst>(Agg)) { // We're extracting from an insertvalue instruction, compare the indices @@ -2233,7 +2311,7 @@ Instruction *InstCombiner::visitExtractValueInst(ExtractValueInst &EV) { // %B = insertvalue { i32, { i32 } } %A, i32 42, 1, 0 // %C = extractvalue { i32, { i32 } } %B, 1, 0 // with "i32 42" - return ReplaceInstUsesWith(EV, IV->getInsertedValueOperand()); + return replaceInstUsesWith(EV, IV->getInsertedValueOperand()); if (exti == exte) { // The extract list is a prefix of the insert list. i.e. replace // %I = insertvalue { i32, { i32 } } %A, i32 42, 1, 0 @@ -2273,8 +2351,8 @@ Instruction *InstCombiner::visitExtractValueInst(ExtractValueInst &EV) { case Intrinsic::sadd_with_overflow: if (*EV.idx_begin() == 0) { // Normal result. Value *LHS = II->getArgOperand(0), *RHS = II->getArgOperand(1); - ReplaceInstUsesWith(*II, UndefValue::get(II->getType())); - EraseInstFromFunction(*II); + replaceInstUsesWith(*II, UndefValue::get(II->getType())); + eraseInstFromFunction(*II); return BinaryOperator::CreateAdd(LHS, RHS); } @@ -2290,8 +2368,8 @@ Instruction *InstCombiner::visitExtractValueInst(ExtractValueInst &EV) { case Intrinsic::ssub_with_overflow: if (*EV.idx_begin() == 0) { // Normal result. Value *LHS = II->getArgOperand(0), *RHS = II->getArgOperand(1); - ReplaceInstUsesWith(*II, UndefValue::get(II->getType())); - EraseInstFromFunction(*II); + replaceInstUsesWith(*II, UndefValue::get(II->getType())); + eraseInstFromFunction(*II); return BinaryOperator::CreateSub(LHS, RHS); } break; @@ -2299,8 +2377,8 @@ Instruction *InstCombiner::visitExtractValueInst(ExtractValueInst &EV) { case Intrinsic::smul_with_overflow: if (*EV.idx_begin() == 0) { // Normal result. Value *LHS = II->getArgOperand(0), *RHS = II->getArgOperand(1); - ReplaceInstUsesWith(*II, UndefValue::get(II->getType())); - EraseInstFromFunction(*II); + replaceInstUsesWith(*II, UndefValue::get(II->getType())); + eraseInstFromFunction(*II); return BinaryOperator::CreateMul(LHS, RHS); } break; @@ -2330,8 +2408,8 @@ Instruction *InstCombiner::visitExtractValueInst(ExtractValueInst &EV) { Value *GEP = Builder->CreateInBoundsGEP(L->getType(), L->getPointerOperand(), Indices); // Returning the load directly will cause the main loop to insert it in - // the wrong spot, so use ReplaceInstUsesWith(). - return ReplaceInstUsesWith(EV, Builder->CreateLoad(GEP)); + // the wrong spot, so use replaceInstUsesWith(). + return replaceInstUsesWith(EV, Builder->CreateLoad(GEP)); } // We could simplify extracts from other values. Note that nested extracts may // already be simplified implicitly by the above: extract (extract (insert) ) @@ -2348,8 +2426,10 @@ Instruction *InstCombiner::visitExtractValueInst(ExtractValueInst &EV) { static bool isCatchAll(EHPersonality Personality, Constant *TypeInfo) { switch (Personality) { case EHPersonality::GNU_C: - // The GCC C EH personality only exists to support cleanups, so it's not - // clear what the semantics of catch clauses are. 
+ case EHPersonality::GNU_C_SjLj: + case EHPersonality::Rust: + // The GCC C EH and Rust personality only exists to support cleanups, so + // it's not clear what the semantics of catch clauses are. return false; case EHPersonality::Unknown: return false; @@ -2358,6 +2438,7 @@ static bool isCatchAll(EHPersonality Personality, Constant *TypeInfo) { // match foreign exceptions (or didn't, before gcc-4.7). return false; case EHPersonality::GNU_CXX: + case EHPersonality::GNU_CXX_SjLj: case EHPersonality::GNU_ObjC: case EHPersonality::MSVC_X86SEH: case EHPersonality::MSVC_Win64SEH: @@ -2701,12 +2782,15 @@ static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock) { &DestBlock->getParent()->getEntryBlock()) return false; + // Do not sink into catchswitch blocks. + if (isa<CatchSwitchInst>(DestBlock->getTerminator())) + return false; + // Do not sink convergent call instructions. if (auto *CI = dyn_cast<CallInst>(I)) { if (CI->isConvergent()) return false; } - // We can only sink load instructions if there is nothing between the load and // the end of block that could change the value. if (I->mayReadFromMemory()) { @@ -2731,7 +2815,7 @@ bool InstCombiner::run() { // Check to see if we can DCE the instruction. if (isInstructionTriviallyDead(I, TLI)) { DEBUG(dbgs() << "IC: DCE: " << *I << '\n'); - EraseInstFromFunction(*I); + eraseInstFromFunction(*I); ++NumDeadInst; MadeIRChange = true; continue; @@ -2744,17 +2828,17 @@ bool InstCombiner::run() { DEBUG(dbgs() << "IC: ConstFold to: " << *C << " from: " << *I << '\n'); // Add operands to the worklist. - ReplaceInstUsesWith(*I, C); + replaceInstUsesWith(*I, C); ++NumConstProp; - EraseInstFromFunction(*I); + eraseInstFromFunction(*I); MadeIRChange = true; continue; } } - // In general, it is possible for computeKnownBits to determine all bits in a - // value even when the operands are not all constants. - if (!I->use_empty() && I->getType()->isIntegerTy()) { + // In general, it is possible for computeKnownBits to determine all bits in + // a value even when the operands are not all constants. + if (ExpensiveCombines && !I->use_empty() && I->getType()->isIntegerTy()) { unsigned BitWidth = I->getType()->getScalarSizeInBits(); APInt KnownZero(BitWidth, 0); APInt KnownOne(BitWidth, 0); @@ -2765,9 +2849,9 @@ bool InstCombiner::run() { " from: " << *I << '\n'); // Add operands to the worklist. - ReplaceInstUsesWith(*I, C); + replaceInstUsesWith(*I, C); ++NumConstProp; - EraseInstFromFunction(*I); + eraseInstFromFunction(*I); MadeIRChange = true; continue; } @@ -2800,6 +2884,7 @@ bool InstCombiner::run() { if (UserIsSuccessor && UserParent->getSinglePredecessor()) { // Okay, the CFG is simple enough, try to sink this instruction. if (TryToSinkInstruction(I, UserParent)) { + DEBUG(dbgs() << "IC: Sink: " << *I << '\n'); MadeIRChange = true; // We'll add uses of the sunk instruction below, but since sinking // can expose opportunities for it's *operands* add them to the @@ -2852,7 +2937,7 @@ bool InstCombiner::run() { InstParent->getInstList().insert(InsertPos, Result); - EraseInstFromFunction(*I); + eraseInstFromFunction(*I); } else { #ifndef NDEBUG DEBUG(dbgs() << "IC: Mod = " << OrigI << '\n' @@ -2862,7 +2947,7 @@ bool InstCombiner::run() { // If the instruction was modified, it's possible that it is now dead. // if so, remove it. 
if (isInstructionTriviallyDead(I, TLI)) { - EraseInstFromFunction(*I); + eraseInstFromFunction(*I); } else { Worklist.Add(I); Worklist.AddUsersToWorkList(*I); @@ -3002,35 +3087,20 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL, // Do a depth-first traversal of the function, populate the worklist with // the reachable instructions. Ignore blocks that are not reachable. Keep // track of which blocks we visit. - SmallPtrSet<BasicBlock *, 64> Visited; + SmallPtrSet<BasicBlock *, 32> Visited; MadeIRChange |= AddReachableCodeToWorklist(&F.front(), DL, Visited, ICWorklist, TLI); // Do a quick scan over the function. If we find any blocks that are // unreachable, remove any instructions inside of them. This prevents // the instcombine code from having to deal with some bad special cases. - for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) { - if (Visited.count(&*BB)) + for (BasicBlock &BB : F) { + if (Visited.count(&BB)) continue; - // Delete the instructions backwards, as it has a reduced likelihood of - // having to update as many def-use and use-def chains. - Instruction *EndInst = BB->getTerminator(); // Last not to be deleted. - while (EndInst != BB->begin()) { - // Delete the next to last instruction. - Instruction *Inst = &*--EndInst->getIterator(); - if (!Inst->use_empty() && !Inst->getType()->isTokenTy()) - Inst->replaceAllUsesWith(UndefValue::get(Inst->getType())); - if (Inst->isEHPad() || Inst->getType()->isTokenTy()) { - EndInst = Inst; - continue; - } - if (!isa<DbgInfoIntrinsic>(Inst)) { - ++NumDeadInst; - MadeIRChange = true; - } - Inst->eraseFromParent(); - } + unsigned NumDeadInstInBB = removeAllNonTerminatorAndEHPadInstructions(&BB); + MadeIRChange |= NumDeadInstInBB > 0; + NumDeadInst += NumDeadInstInBB; } return MadeIRChange; @@ -3040,12 +3110,14 @@ static bool combineInstructionsOverFunction(Function &F, InstCombineWorklist &Worklist, AliasAnalysis *AA, AssumptionCache &AC, TargetLibraryInfo &TLI, DominatorTree &DT, + bool ExpensiveCombines = true, LoopInfo *LI = nullptr) { auto &DL = F.getParent()->getDataLayout(); + ExpensiveCombines |= EnableExpensiveCombines; /// Builder - This is an IRBuilder that automatically inserts new /// instructions into the worklist when they are created. 
- IRBuilder<true, TargetFolder, InstCombineIRInserter> Builder( + IRBuilder<TargetFolder, InstCombineIRInserter> Builder( F.getContext(), TargetFolder(DL), InstCombineIRInserter(Worklist, &AC)); // Lower dbg.declare intrinsics otherwise their value may be clobbered @@ -3059,14 +3131,11 @@ combineInstructionsOverFunction(Function &F, InstCombineWorklist &Worklist, DEBUG(dbgs() << "\n\nINSTCOMBINE ITERATION #" << Iteration << " on " << F.getName() << "\n"); - bool Changed = false; - if (prepareICWorklistFromFunction(F, DL, &TLI, Worklist)) - Changed = true; + bool Changed = prepareICWorklistFromFunction(F, DL, &TLI, Worklist); - InstCombiner IC(Worklist, &Builder, F.optForMinSize(), + InstCombiner IC(Worklist, &Builder, F.optForMinSize(), ExpensiveCombines, AA, &AC, &TLI, &DT, DL, LI); - if (IC.run()) - Changed = true; + Changed |= IC.run(); if (!Changed) break; @@ -3076,45 +3145,26 @@ combineInstructionsOverFunction(Function &F, InstCombineWorklist &Worklist, } PreservedAnalyses InstCombinePass::run(Function &F, - AnalysisManager<Function> *AM) { - auto &AC = AM->getResult<AssumptionAnalysis>(F); - auto &DT = AM->getResult<DominatorTreeAnalysis>(F); - auto &TLI = AM->getResult<TargetLibraryAnalysis>(F); + AnalysisManager<Function> &AM) { + auto &AC = AM.getResult<AssumptionAnalysis>(F); + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); - auto *LI = AM->getCachedResult<LoopAnalysis>(F); + auto *LI = AM.getCachedResult<LoopAnalysis>(F); // FIXME: The AliasAnalysis is not yet supported in the new pass manager - if (!combineInstructionsOverFunction(F, Worklist, nullptr, AC, TLI, DT, LI)) + if (!combineInstructionsOverFunction(F, Worklist, nullptr, AC, TLI, DT, + ExpensiveCombines, LI)) // No changes, all analyses are preserved. return PreservedAnalyses::all(); // Mark all the analyses that instcombine updates as preserved. - // FIXME: Need a way to preserve CFG analyses here! + // FIXME: This should also 'preserve the CFG'. PreservedAnalyses PA; PA.preserve<DominatorTreeAnalysis>(); return PA; } -namespace { -/// \brief The legacy pass manager's instcombine pass. -/// -/// This is a basic whole-function wrapper around the instcombine utility. It -/// will try to combine all instructions in the function. -class InstructionCombiningPass : public FunctionPass { - InstCombineWorklist Worklist; - -public: - static char ID; // Pass identification, replacement for typeid - - InstructionCombiningPass() : FunctionPass(ID) { - initializeInstructionCombiningPassPass(*PassRegistry::getPassRegistry()); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override; - bool runOnFunction(Function &F) override; -}; -} - void InstructionCombiningPass::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesCFG(); AU.addRequired<AAResultsWrapperPass>(); @@ -3122,11 +3172,13 @@ void InstructionCombiningPass::getAnalysisUsage(AnalysisUsage &AU) const { AU.addRequired<TargetLibraryInfoWrapperPass>(); AU.addRequired<DominatorTreeWrapperPass>(); AU.addPreserved<DominatorTreeWrapperPass>(); + AU.addPreserved<AAResultsWrapperPass>(); + AU.addPreserved<BasicAAWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); } bool InstructionCombiningPass::runOnFunction(Function &F) { - if (skipOptnoneFunction(F)) + if (skipFunction(F)) return false; // Required analyses. @@ -3139,7 +3191,8 @@ bool InstructionCombiningPass::runOnFunction(Function &F) { auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>(); auto *LI = LIWP ? 
&LIWP->getLoopInfo() : nullptr; - return combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, DT, LI); + return combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, DT, + ExpensiveCombines, LI); } char InstructionCombiningPass::ID = 0; @@ -3162,6 +3215,6 @@ void LLVMInitializeInstCombine(LLVMPassRegistryRef R) { initializeInstructionCombiningPassPass(*unwrap(R)); } -FunctionPass *llvm::createInstructionCombiningPass() { - return new InstructionCombiningPass(); +FunctionPass *llvm::createInstructionCombiningPass(bool ExpensiveCombines) { + return new InstructionCombiningPass(ExpensiveCombines); } diff --git a/lib/Transforms/InstCombine/Makefile b/lib/Transforms/InstCombine/Makefile deleted file mode 100644 index 0c488e78b6d92..0000000000000 --- a/lib/Transforms/InstCombine/Makefile +++ /dev/null @@ -1,15 +0,0 @@ -##===- lib/Transforms/InstCombine/Makefile -----------------*- Makefile -*-===## -# -# The LLVM Compiler Infrastructure -# -# This file is distributed under the University of Illinois Open Source -# License. See LICENSE.TXT for details. -# -##===----------------------------------------------------------------------===## - -LEVEL = ../../.. -LIBRARYNAME = LLVMInstCombine -BUILD_ARCHIVE = 1 - -include $(LEVEL)/Makefile.common -