author    | Dimitry Andric <dim@FreeBSD.org> | 2017-12-18 20:10:56 +0000
committer | Dimitry Andric <dim@FreeBSD.org> | 2017-12-18 20:10:56 +0000
commit    | 044eb2f6afba375a914ac9d8024f8f5142bb912e (patch)
tree      | 1475247dc9f9fe5be155ebd4c9069c75aadf8c20 /lib/Transforms/InstCombine
parent    | eb70dddbd77e120e5d490bd8fbe7ff3f8fa81c6b (diff)
Diffstat (limited to 'lib/Transforms/InstCombine')
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineAddSub.cpp | 320
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineAndOrXor.cpp | 614
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineCalls.cpp | 174
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineCasts.cpp | 206
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineCompares.cpp | 737
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineInternal.h | 142
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp | 90
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineMulDivRem.cpp | 186
-rw-r--r-- | lib/Transforms/InstCombine/InstCombinePHI.cpp | 259
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineSelect.cpp | 506
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineShifts.cpp | 151
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp | 94
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineVectorOps.cpp | 111
-rw-r--r-- | lib/Transforms/InstCombine/InstructionCombining.cpp | 188
14 files changed, 2220 insertions, 1558 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/lib/Transforms/InstCombine/InstCombineAddSub.cpp index 809471cfd74f..688897644848 100644 --- a/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -12,12 +12,26 @@ //===----------------------------------------------------------------------===// #include "InstCombineInternal.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/InstructionSimplify.h" -#include "llvm/IR/DataLayout.h" -#include "llvm/IR/GetElementPtrTypeIterator.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/AlignOf.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/KnownBits.h" +#include <cassert> +#include <utility> using namespace llvm; using namespace PatternMatch; @@ -39,10 +53,15 @@ namespace { // is expensive. In order to avoid the cost of the constructor, we should // reuse some instances whenever possible. The pre-created instances // FAddCombine::Add[0-5] embodies this idea. - // - FAddendCoef() : IsFp(false), BufHasFpVal(false), IntVal(0) {} + FAddendCoef() = default; ~FAddendCoef(); + // If possible, don't define operator+/operator- etc because these + // operators inevitably call FAddendCoef's constructor which is not cheap. + void operator=(const FAddendCoef &A); + void operator+=(const FAddendCoef &A); + void operator*=(const FAddendCoef &S); + void set(short C) { assert(!insaneIntVal(C) && "Insane coefficient"); IsFp = false; IntVal = C; @@ -55,12 +74,6 @@ namespace { bool isZero() const { return isInt() ? !IntVal : getFpVal().isZero(); } Value *getValue(Type *) const; - // If possible, don't define operator+/operator- etc because these - // operators inevitably call FAddendCoef's constructor which is not cheap. - void operator=(const FAddendCoef &A); - void operator+=(const FAddendCoef &A); - void operator*=(const FAddendCoef &S); - bool isOne() const { return isInt() && IntVal == 1; } bool isTwo() const { return isInt() && IntVal == 2; } bool isMinusOne() const { return isInt() && IntVal == -1; } @@ -68,10 +81,12 @@ namespace { private: bool insaneIntVal(int V) { return V > 4 || V < -4; } + APFloat *getFpValPtr() - { return reinterpret_cast<APFloat*>(&FpValBuf.buffer[0]); } + { return reinterpret_cast<APFloat *>(&FpValBuf.buffer[0]); } + const APFloat *getFpValPtr() const - { return reinterpret_cast<const APFloat*>(&FpValBuf.buffer[0]); } + { return reinterpret_cast<const APFloat *>(&FpValBuf.buffer[0]); } const APFloat &getFpVal() const { assert(IsFp && BufHasFpVal && "Incorret state"); @@ -94,17 +109,16 @@ namespace { // from an *SIGNED* integer. APFloat createAPFloatFromInt(const fltSemantics &Sem, int Val); - private: - bool IsFp; + bool IsFp = false; // True iff FpValBuf contains an instance of APFloat. - bool BufHasFpVal; + bool BufHasFpVal = false; // The integer coefficient of an individual addend is either 1 or -1, // and we try to simplify at most 4 addends from neighboring at most // two instructions. So the range of <IntVal> falls in [-4, 4]. APInt // is overkill of this end. 
- short IntVal; + short IntVal = 0; AlignedCharArrayUnion<APFloat> FpValBuf; }; @@ -112,10 +126,14 @@ namespace { /// FAddend is used to represent floating-point addend. An addend is /// represented as <C, V>, where the V is a symbolic value, and C is a /// constant coefficient. A constant addend is represented as <C, 0>. - /// class FAddend { public: - FAddend() : Val(nullptr) {} + FAddend() = default; + + void operator+=(const FAddend &T) { + assert((Val == T.Val) && "Symbolic-values disagree"); + Coeff += T.Coeff; + } Value *getSymVal() const { return Val; } const FAddendCoef &getCoef() const { return Coeff; } @@ -146,16 +164,11 @@ namespace { /// splitted is the addend itself. unsigned drillAddendDownOneStep(FAddend &Addend0, FAddend &Addend1) const; - void operator+=(const FAddend &T) { - assert((Val == T.Val) && "Symbolic-values disagree"); - Coeff += T.Coeff; - } - private: void Scale(const FAddendCoef& ScaleAmt) { Coeff *= ScaleAmt; } // This addend has the value of "Coeff * Val". - Value *Val; + Value *Val = nullptr; FAddendCoef Coeff; }; @@ -164,11 +177,12 @@ namespace { /// class FAddCombine { public: - FAddCombine(InstCombiner::BuilderTy &B) : Builder(B), Instr(nullptr) {} + FAddCombine(InstCombiner::BuilderTy &B) : Builder(B) {} + Value *simplify(Instruction *FAdd); private: - typedef SmallVector<const FAddend*, 4> AddendVect; + using AddendVect = SmallVector<const FAddend *, 4>; Value *simplifyFAdd(AddendVect& V, unsigned InstrQuota); @@ -179,6 +193,7 @@ namespace { /// Return the number of instructions needed to emit the N-ary addition. unsigned calcInstrNumber(const AddendVect& Vect); + Value *createFSub(Value *Opnd0, Value *Opnd1); Value *createFAdd(Value *Opnd0, Value *Opnd1); Value *createFMul(Value *Opnd0, Value *Opnd1); @@ -187,9 +202,6 @@ namespace { Value *createNaryFAdd(const AddendVect& Opnds, unsigned InstrQuota); void createInstPostProc(Instruction *NewInst, bool NoNumber = false); - InstCombiner::BuilderTy &Builder; - Instruction *Instr; - // Debugging stuff are clustered here. #ifndef NDEBUG unsigned CreateInstrNum; @@ -199,9 +211,12 @@ namespace { void initCreateInstNum() {} void incCreateInstNum() {} #endif + + InstCombiner::BuilderTy &Builder; + Instruction *Instr = nullptr; }; -} // anonymous namespace +} // end anonymous namespace //===----------------------------------------------------------------------===// // @@ -332,7 +347,6 @@ Value *FAddendCoef::getValue(Type *Ty) const { // 0 +/- 0 <0, NULL> (corner case) // // Legend: A and B are not constant, C is constant -// unsigned FAddend::drillValueDownOneStep (Value *Val, FAddend &Addend0, FAddend &Addend1) { Instruction *I = nullptr; @@ -396,7 +410,6 @@ unsigned FAddend::drillValueDownOneStep // Try to break *this* addend into two addends. e.g. Suppose this addend is // <2.3, V>, and V = X + Y, by calling this function, we obtain two addends, // i.e. <2.3, X> and <2.3, Y>. 
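// Aside (illustrative sketch, not part of the patch): the bookkeeping that
// FAddend models, redone with plain doubles and string "symbols" instead of
// the real FAddendCoef/Value machinery. <2.3, x+y> is split into <2.3, x>
// and <2.3, y>, and like terms are merged by summing their coefficients; the
// rewrite is only valid under fast-math, since FP addition is not exactly
// associative or distributive.
#include <cassert>
#include <map>
#include <string>

static void addAddend(std::map<std::string, double> &Addends,
                      const std::string &Sym, double Coeff) {
  Addends[Sym] += Coeff; // analogous to merging like terms via operator+=
}

int main() {
  std::map<std::string, double> Addends;
  addAddend(Addends, "x", 2.3);  // from 2.3 * (x + y)
  addAddend(Addends, "y", 2.3);
  addAddend(Addends, "x", -2.3); // from -2.3 * x
  assert(Addends["x"] == 0.0 && Addends["y"] == 2.3);
  return 0;
}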
-// unsigned FAddend::drillAddendDownOneStep (FAddend &Addend0, FAddend &Addend1) const { if (isConstant()) @@ -421,7 +434,6 @@ unsigned FAddend::drillAddendDownOneStep // ------------------------------------------------------- // (x * y) +/- (x * z) x * (y +/- z) // (y / x) +/- (z / x) (y +/- z) / x -// Value *FAddCombine::performFactorization(Instruction *I) { assert((I->getOpcode() == Instruction::FAdd || I->getOpcode() == Instruction::FSub) && "Expect add/sub"); @@ -447,7 +459,6 @@ Value *FAddCombine::performFactorization(Instruction *I) { // ---------------------------------------------- // (x*y) +/- (x*z) x y z // (y/x) +/- (z/x) x y z - // Value *Factor = nullptr; Value *AddSub0 = nullptr, *AddSub1 = nullptr; @@ -471,7 +482,7 @@ Value *FAddCombine::performFactorization(Instruction *I) { return nullptr; FastMathFlags Flags; - Flags.setUnsafeAlgebra(); + Flags.setFast(); if (I0) Flags &= I->getFastMathFlags(); if (I1) Flags &= I->getFastMathFlags(); @@ -500,7 +511,7 @@ Value *FAddCombine::performFactorization(Instruction *I) { } Value *FAddCombine::simplify(Instruction *I) { - assert(I->hasUnsafeAlgebra() && "Should be in unsafe mode"); + assert(I->isFast() && "Expected 'fast' instruction"); // Currently we are not able to handle vector type. if (I->getType()->isVectorTy()) @@ -599,7 +610,6 @@ Value *FAddCombine::simplifyFAdd(AddendVect& Addends, unsigned InstrQuota) { // desirable to reside at the top of the resulting expression tree. Placing // constant close to supper-expr(s) will potentially reveal some optimization // opportunities in super-expr(s). - // const FAddend *ConstAdd = nullptr; // Simplified addends are placed <SimpVect>. @@ -608,7 +618,6 @@ Value *FAddCombine::simplifyFAdd(AddendVect& Addends, unsigned InstrQuota) { // The outer loop works on one symbolic-value at a time. Suppose the input // addends are : <a1, x>, <b1, y>, <a2, x>, <c1, z>, <b2, y>, ... // The symbolic-values will be processed in this order: x, y, z. - // for (unsigned SymIdx = 0; SymIdx < AddendNum; SymIdx++) { const FAddend *ThisAddend = Addends[SymIdx]; @@ -626,7 +635,6 @@ Value *FAddCombine::simplifyFAdd(AddendVect& Addends, unsigned InstrQuota) { // example, if the symbolic value "y" is being processed, the inner loop // will collect two addends "<b1,y>" and "<b2,Y>". These two addends will // be later on folded into "<b1+b2, y>". - // for (unsigned SameSymIdx = SymIdx + 1; SameSymIdx < AddendNum; SameSymIdx++) { const FAddend *T = Addends[SameSymIdx]; @@ -681,7 +689,7 @@ Value *FAddCombine::createNaryFAdd assert(!Opnds.empty() && "Expect at least one addend"); // Step 1: Check if the # of instructions needed exceeds the quota. 
- // + unsigned InstrNeeded = calcInstrNumber(Opnds); if (InstrNeeded > InstrQuota) return nullptr; @@ -726,10 +734,10 @@ Value *FAddCombine::createNaryFAdd LastVal = createFNeg(LastVal); } - #ifndef NDEBUG - assert(CreateInstrNum == InstrNeeded && - "Inconsistent in instruction numbers"); - #endif +#ifndef NDEBUG + assert(CreateInstrNum == InstrNeeded && + "Inconsistent in instruction numbers"); +#endif return LastVal; } @@ -950,9 +958,25 @@ static Value *checkForNegativeOperand(BinaryOperator &I, return nullptr; } -static Instruction *foldAddWithConstant(BinaryOperator &Add, - InstCombiner::BuilderTy &Builder) { +Instruction *InstCombiner::foldAddWithConstant(BinaryOperator &Add) { Value *Op0 = Add.getOperand(0), *Op1 = Add.getOperand(1); + Constant *Op1C; + if (!match(Op1, m_Constant(Op1C))) + return nullptr; + + if (Instruction *NV = foldOpWithConstantIntoOperand(Add)) + return NV; + + Value *X; + // zext(bool) + C -> bool ? C + 1 : C + if (match(Op0, m_ZExt(m_Value(X))) && + X->getType()->getScalarSizeInBits() == 1) + return SelectInst::Create(X, AddOne(Op1C), Op1); + + // ~X + C --> (C-1) - X + if (match(Op0, m_Not(m_Value(X)))) + return BinaryOperator::CreateSub(SubOne(Op1C), X); + const APInt *C; if (!match(Op1, m_APInt(C))) return nullptr; @@ -968,21 +992,17 @@ static Instruction *foldAddWithConstant(BinaryOperator &Add, return BinaryOperator::CreateXor(Op0, Op1); } - Value *X; - const APInt *C2; - Type *Ty = Add.getType(); - // Is this add the last step in a convoluted sext? // add(zext(xor i16 X, -32768), -32768) --> sext X + Type *Ty = Add.getType(); + const APInt *C2; if (match(Op0, m_ZExt(m_Xor(m_Value(X), m_APInt(C2)))) && C2->isMinSignedValue() && C2->sext(Ty->getScalarSizeInBits()) == *C) return CastInst::Create(Instruction::SExt, X, Ty); // (add (zext (add nuw X, C2)), C) --> (zext (add nuw X, C2 + C)) - // FIXME: This should check hasOneUse to not increase the instruction count? - if (C->isNegative() && - match(Op0, m_ZExt(m_NUWAdd(m_Value(X), m_APInt(C2)))) && - C->sge(-C2->sext(C->getBitWidth()))) { + if (match(Op0, m_OneUse(m_ZExt(m_NUWAdd(m_Value(X), m_APInt(C2))))) && + C->isNegative() && C->sge(-C2->sext(C->getBitWidth()))) { Constant *NewC = ConstantInt::get(X->getType(), *C2 + C->trunc(C2->getBitWidth())); return new ZExtInst(Builder.CreateNUWAdd(X, NewC), Ty); @@ -1013,34 +1033,29 @@ static Instruction *foldAddWithConstant(BinaryOperator &Add, Instruction *InstCombiner::visitAdd(BinaryOperator &I) { bool Changed = SimplifyAssociativeOrCommutative(I); - Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); - if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); + Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); if (Value *V = SimplifyAddInst(LHS, RHS, I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - // (A*B)+(A*C) -> A*(B+C) etc + // (A*B)+(A*C) -> A*(B+C) etc if (Value *V = SimplifyUsingDistributiveLaws(I)) return replaceInstUsesWith(I, V); - if (Instruction *X = foldAddWithConstant(I, Builder)) + if (Instruction *X = foldAddWithConstant(I)) return X; // FIXME: This should be moved into the above helper function to allow these - // transforms for splat vectors. + // transforms for general constant or constant splat vectors. + Type *Ty = I.getType(); if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) { - // zext(bool) + C -> bool ? 
C + 1 : C - if (ZExtInst *ZI = dyn_cast<ZExtInst>(LHS)) - if (ZI->getSrcTy()->isIntegerTy(1)) - return SelectInst::Create(ZI->getOperand(0), AddOne(CI), CI); - Value *XorLHS = nullptr; ConstantInt *XorRHS = nullptr; if (match(LHS, m_Xor(m_Value(XorLHS), m_ConstantInt(XorRHS)))) { - uint32_t TySizeBits = I.getType()->getScalarSizeInBits(); + unsigned TySizeBits = Ty->getScalarSizeInBits(); const APInt &RHSVal = CI->getValue(); unsigned ExtendAmt = 0; // If we have ADD(XOR(AND(X, 0xFF), 0x80), 0xF..F80), it's a sext. @@ -1059,7 +1074,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { } if (ExtendAmt) { - Constant *ShAmt = ConstantInt::get(I.getType(), ExtendAmt); + Constant *ShAmt = ConstantInt::get(Ty, ExtendAmt); Value *NewShl = Builder.CreateShl(XorLHS, ShAmt, "sext"); return BinaryOperator::CreateAShr(NewShl, ShAmt); } @@ -1080,38 +1095,30 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { } } - if (isa<Constant>(RHS)) - if (Instruction *NV = foldOpWithConstantIntoOperand(I)) - return NV; - - if (I.getType()->isIntOrIntVectorTy(1)) + if (Ty->isIntOrIntVectorTy(1)) return BinaryOperator::CreateXor(LHS, RHS); // X + X --> X << 1 if (LHS == RHS) { - BinaryOperator *New = - BinaryOperator::CreateShl(LHS, ConstantInt::get(I.getType(), 1)); - New->setHasNoSignedWrap(I.hasNoSignedWrap()); - New->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); - return New; + auto *Shl = BinaryOperator::CreateShl(LHS, ConstantInt::get(Ty, 1)); + Shl->setHasNoSignedWrap(I.hasNoSignedWrap()); + Shl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); + return Shl; } - // -A + B --> B - A - // -A + -B --> -(A + B) - if (Value *LHSV = dyn_castNegVal(LHS)) { - if (!isa<Constant>(RHS)) - if (Value *RHSV = dyn_castNegVal(RHS)) { - Value *NewAdd = Builder.CreateAdd(LHSV, RHSV, "sum"); - return BinaryOperator::CreateNeg(NewAdd); - } + Value *A, *B; + if (match(LHS, m_Neg(m_Value(A)))) { + // -A + -B --> -(A + B) + if (match(RHS, m_Neg(m_Value(B)))) + return BinaryOperator::CreateNeg(Builder.CreateAdd(A, B)); - return BinaryOperator::CreateSub(RHS, LHSV); + // -A + B --> B - A + return BinaryOperator::CreateSub(RHS, A); } // A + -B --> A - B - if (!isa<Constant>(RHS)) - if (Value *V = dyn_castNegVal(RHS)) - return BinaryOperator::CreateSub(LHS, V); + if (match(RHS, m_Neg(m_Value(B)))) + return BinaryOperator::CreateSub(LHS, B); if (Value *V = checkForNegativeOperand(I, Builder)) return replaceInstUsesWith(I, V); @@ -1120,12 +1127,6 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { if (haveNoCommonBitsSet(LHS, RHS, DL, &AC, &I, &DT)) return BinaryOperator::CreateOr(LHS, RHS); - if (Constant *CRHS = dyn_cast<Constant>(RHS)) { - Value *X; - if (match(LHS, m_Not(m_Value(X)))) // ~X + C --> (C-1) - X - return BinaryOperator::CreateSub(SubOne(CRHS), X); - } - // FIXME: We already did a check for ConstantInt RHS above this. // FIXME: Is this pattern covered by another fold? No regression tests fail on // removal. @@ -1187,12 +1188,12 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { if (LHSConv->hasOneUse()) { Constant *CI = ConstantExpr::getTrunc(RHSC, LHSConv->getOperand(0)->getType()); - if (ConstantExpr::getSExt(CI, I.getType()) == RHSC && + if (ConstantExpr::getSExt(CI, Ty) == RHSC && willNotOverflowSignedAdd(LHSConv->getOperand(0), CI, I)) { // Insert the new, smaller add. 
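// Aside (illustrative, not part of the patch): why the add can be done in
// the narrow type first. If the constant round-trips through i8 unchanged
// and the narrow add cannot overflow, sign-extending after the add gives
// the same value as adding after sign-extension. With overflow (e.g.
// 100 + 100) the two differ, which is what willNotOverflowSignedAdd rules out.
#include <cassert>
#include <cstdint>

int main() {
  int8_t X = 100;
  int32_t C = 20;                               // sext(trunc(C)) == C
  int8_t Narrow = static_cast<int8_t>(X + 20);  // add nsw in i8
  int32_t Wide = static_cast<int32_t>(X) + C;   // add after sext to i32
  assert(static_cast<int32_t>(Narrow) == Wide);
  return 0;
}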
Value *NewAdd = Builder.CreateNSWAdd(LHSConv->getOperand(0), CI, "addconv"); - return new SExtInst(NewAdd, I.getType()); + return new SExtInst(NewAdd, Ty); } } } @@ -1210,7 +1211,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { // Insert the new integer add. Value *NewAdd = Builder.CreateNSWAdd(LHSConv->getOperand(0), RHSConv->getOperand(0), "addconv"); - return new SExtInst(NewAdd, I.getType()); + return new SExtInst(NewAdd, Ty); } } } @@ -1223,12 +1224,12 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { if (LHSConv->hasOneUse()) { Constant *CI = ConstantExpr::getTrunc(RHSC, LHSConv->getOperand(0)->getType()); - if (ConstantExpr::getZExt(CI, I.getType()) == RHSC && + if (ConstantExpr::getZExt(CI, Ty) == RHSC && willNotOverflowUnsignedAdd(LHSConv->getOperand(0), CI, I)) { // Insert the new, smaller add. Value *NewAdd = Builder.CreateNUWAdd(LHSConv->getOperand(0), CI, "addconv"); - return new ZExtInst(NewAdd, I.getType()); + return new ZExtInst(NewAdd, Ty); } } } @@ -1246,41 +1247,35 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { // Insert the new integer add. Value *NewAdd = Builder.CreateNUWAdd( LHSConv->getOperand(0), RHSConv->getOperand(0), "addconv"); - return new ZExtInst(NewAdd, I.getType()); + return new ZExtInst(NewAdd, Ty); } } } // (add (xor A, B) (and A, B)) --> (or A, B) - { - Value *A = nullptr, *B = nullptr; - if (match(RHS, m_Xor(m_Value(A), m_Value(B))) && - match(LHS, m_c_And(m_Specific(A), m_Specific(B)))) - return BinaryOperator::CreateOr(A, B); - - if (match(LHS, m_Xor(m_Value(A), m_Value(B))) && - match(RHS, m_c_And(m_Specific(A), m_Specific(B)))) - return BinaryOperator::CreateOr(A, B); - } + if (match(LHS, m_Xor(m_Value(A), m_Value(B))) && + match(RHS, m_c_And(m_Specific(A), m_Specific(B)))) + return BinaryOperator::CreateOr(A, B); + + // (add (and A, B) (xor A, B)) --> (or A, B) + if (match(RHS, m_Xor(m_Value(A), m_Value(B))) && + match(LHS, m_c_And(m_Specific(A), m_Specific(B)))) + return BinaryOperator::CreateOr(A, B); // (add (or A, B) (and A, B)) --> (add A, B) - { - Value *A = nullptr, *B = nullptr; - if (match(RHS, m_Or(m_Value(A), m_Value(B))) && - match(LHS, m_c_And(m_Specific(A), m_Specific(B)))) { - auto *New = BinaryOperator::CreateAdd(A, B); - New->setHasNoSignedWrap(I.hasNoSignedWrap()); - New->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); - return New; - } + if (match(LHS, m_Or(m_Value(A), m_Value(B))) && + match(RHS, m_c_And(m_Specific(A), m_Specific(B)))) { + I.setOperand(0, A); + I.setOperand(1, B); + return &I; + } - if (match(LHS, m_Or(m_Value(A), m_Value(B))) && - match(RHS, m_c_And(m_Specific(A), m_Specific(B)))) { - auto *New = BinaryOperator::CreateAdd(A, B); - New->setHasNoSignedWrap(I.hasNoSignedWrap()); - New->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); - return New; - } + // (add (and A, B) (or A, B)) --> (add A, B) + if (match(RHS, m_Or(m_Value(A), m_Value(B))) && + match(LHS, m_c_And(m_Specific(A), m_Specific(B)))) { + I.setOperand(0, A); + I.setOperand(1, B); + return &I; } // TODO(jingyue): Consider willNotOverflowSignedAdd and @@ -1387,32 +1382,11 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { } } - // select C, 0, B + select C, A, 0 -> select C, A, B - { - Value *A1, *B1, *C1, *A2, *B2, *C2; - if (match(LHS, m_Select(m_Value(C1), m_Value(A1), m_Value(B1))) && - match(RHS, m_Select(m_Value(C2), m_Value(A2), m_Value(B2)))) { - if (C1 == C2) { - Constant *Z1=nullptr, *Z2=nullptr; - Value *A, *B, *C=C1; - if (match(A1, m_AnyZero()) && match(B2, m_AnyZero())) { - Z1 = 
dyn_cast<Constant>(A1); A = A2; - Z2 = dyn_cast<Constant>(B2); B = B1; - } else if (match(B1, m_AnyZero()) && match(A2, m_AnyZero())) { - Z1 = dyn_cast<Constant>(B1); B = B2; - Z2 = dyn_cast<Constant>(A2); A = A1; - } - - if (Z1 && Z2 && - (I.hasNoSignedZeros() || - (Z1->isNegativeZeroValue() && Z2->isNegativeZeroValue()))) { - return SelectInst::Create(C, A, B); - } - } - } - } + // Handle specials cases for FAdd with selects feeding the operation + if (Value *V = SimplifySelectsFeedingBinaryOp(I, LHS, RHS)) + return replaceInstUsesWith(I, V); - if (I.hasUnsafeAlgebra()) { + if (I.isFast()) { if (Value *V = FAddCombine(Builder).simplify(&I)) return replaceInstUsesWith(I, V); } @@ -1423,7 +1397,6 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { /// Optimize pointer differences into the same array into a size. Consider: /// &A[10] - &A[0]: we should compile this to "10". LHS/RHS are the pointer /// operands to the ptrtoint instructions for the LHS/RHS of the subtract. -/// Value *InstCombiner::OptimizePointerDifference(Value *LHS, Value *RHS, Type *Ty) { // If LHS is a gep based on RHS or RHS is a gep based on LHS, we can optimize @@ -1465,12 +1438,31 @@ Value *InstCombiner::OptimizePointerDifference(Value *LHS, Value *RHS, } } - // Avoid duplicating the arithmetic if GEP2 has non-constant indices and - // multiple users. - if (!GEP1 || - (GEP2 && !GEP2->hasAllConstantIndices() && !GEP2->hasOneUse())) + if (!GEP1) + // No GEP found. return nullptr; + if (GEP2) { + // (gep X, ...) - (gep X, ...) + // + // Avoid duplicating the arithmetic if there are more than one non-constant + // indices between the two GEPs and either GEP has a non-constant index and + // multiple users. If zero non-constant index, the result is a constant and + // there is no duplication. If one non-constant index, the result is an add + // or sub with a constant, which is no larger than the original code, and + // there's no duplicated arithmetic, even if either GEP has multiple + // users. If more than one non-constant indices combined, as long as the GEP + // with at least one non-constant index doesn't have multiple users, there + // is no duplication. + unsigned NumNonConstantIndices1 = GEP1->countNonConstantIndices(); + unsigned NumNonConstantIndices2 = GEP2->countNonConstantIndices(); + if (NumNonConstantIndices1 + NumNonConstantIndices2 > 1 && + ((NumNonConstantIndices1 > 0 && !GEP1->hasOneUse()) || + (NumNonConstantIndices2 > 0 && !GEP2->hasOneUse()))) { + return nullptr; + } + } + // Emit the offset of the GEP and an intptr_t. Value *Result = EmitGEPOffset(GEP1); @@ -1528,8 +1520,13 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { return BinaryOperator::CreateNot(Op1); if (Constant *C = dyn_cast<Constant>(Op0)) { + Value *X; + // C - zext(bool) -> bool ? 
C - 1 : C + if (match(Op1, m_ZExt(m_Value(X))) && + X->getType()->getScalarSizeInBits() == 1) + return SelectInst::Create(X, SubOne(C), C); + // C - ~X == X + (1+C) - Value *X = nullptr; if (match(Op1, m_Not(m_Value(X)))) return BinaryOperator::CreateAdd(X, AddOne(C)); @@ -1600,7 +1597,7 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { return BinaryOperator::CreateNeg(Y); } - // (sub (or A, B) (xor A, B)) --> (and A, B) + // (sub (or A, B), (xor A, B)) --> (and A, B) { Value *A, *B; if (match(Op1, m_Xor(m_Value(A), m_Value(B))) && @@ -1626,7 +1623,6 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { Builder.CreateSub(Z, Y, Op1->getName())); // (X - (X & Y)) --> (X & ~Y) - // if (match(Op1, m_c_And(m_Value(Y), m_Specific(Op0)))) return BinaryOperator::CreateAnd(Op0, Builder.CreateNot(Y, Y->getName() + ".not")); @@ -1741,7 +1737,11 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { } } - if (I.hasUnsafeAlgebra()) { + // Handle specials cases for FSub with selects feeding the operation + if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1)) + return replaceInstUsesWith(I, V); + + if (I.isFast()) { if (Value *V = FAddCombine(Builder).simplify(&I)) return replaceInstUsesWith(I, V); } diff --git a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index fdc9c373b95e..2364202e5b69 100644 --- a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -12,11 +12,11 @@ //===----------------------------------------------------------------------===// #include "InstCombineInternal.h" +#include "llvm/Analysis/CmpInstAnalysis.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/PatternMatch.h" -#include "llvm/Transforms/Utils/CmpInstAnalysis.h" #include "llvm/Transforms/Utils/Local.h" using namespace llvm; using namespace PatternMatch; @@ -120,35 +120,9 @@ Instruction *InstCombiner::OptAndOp(BinaryOperator *Op, ConstantInt *AndRHS, BinaryOperator &TheAnd) { Value *X = Op->getOperand(0); - Constant *Together = nullptr; - if (!Op->isShift()) - Together = ConstantExpr::getAnd(AndRHS, OpRHS); switch (Op->getOpcode()) { default: break; - case Instruction::Xor: - if (Op->hasOneUse()) { - // (X ^ C1) & C2 --> (X & C2) ^ (C1&C2) - Value *And = Builder.CreateAnd(X, AndRHS); - And->takeName(Op); - return BinaryOperator::CreateXor(And, Together); - } - break; - case Instruction::Or: - if (Op->hasOneUse()){ - ConstantInt *TogetherCI = dyn_cast<ConstantInt>(Together); - if (TogetherCI && !TogetherCI->isZero()){ - // (X | C1) & C2 --> (X & (C2^(C1&C2))) | C1 - // NOTE: This reduces the number of bits set in the & mask, which - // can expose opportunities for store narrowing. - Together = ConstantExpr::getXor(AndRHS, Together); - Value *And = Builder.CreateAnd(X, Together); - And->takeName(Op); - return BinaryOperator::CreateOr(And, OpRHS); - } - } - - break; case Instruction::Add: if (Op->hasOneUse()) { // Adding a one to a single bit bit-field should be turned into an XOR @@ -182,64 +156,6 @@ Instruction *InstCombiner::OptAndOp(BinaryOperator *Op, } } break; - - case Instruction::Shl: { - // We know that the AND will not produce any of the bits shifted in, so if - // the anded constant includes them, clear them now! 
- // - uint32_t BitWidth = AndRHS->getType()->getBitWidth(); - uint32_t OpRHSVal = OpRHS->getLimitedValue(BitWidth); - APInt ShlMask(APInt::getHighBitsSet(BitWidth, BitWidth-OpRHSVal)); - ConstantInt *CI = Builder.getInt(AndRHS->getValue() & ShlMask); - - if (CI->getValue() == ShlMask) - // Masking out bits that the shift already masks. - return replaceInstUsesWith(TheAnd, Op); // No need for the and. - - if (CI != AndRHS) { // Reducing bits set in and. - TheAnd.setOperand(1, CI); - return &TheAnd; - } - break; - } - case Instruction::LShr: { - // We know that the AND will not produce any of the bits shifted in, so if - // the anded constant includes them, clear them now! This only applies to - // unsigned shifts, because a signed shr may bring in set bits! - // - uint32_t BitWidth = AndRHS->getType()->getBitWidth(); - uint32_t OpRHSVal = OpRHS->getLimitedValue(BitWidth); - APInt ShrMask(APInt::getLowBitsSet(BitWidth, BitWidth - OpRHSVal)); - ConstantInt *CI = Builder.getInt(AndRHS->getValue() & ShrMask); - - if (CI->getValue() == ShrMask) - // Masking out bits that the shift already masks. - return replaceInstUsesWith(TheAnd, Op); - - if (CI != AndRHS) { - TheAnd.setOperand(1, CI); // Reduce bits set in and cst. - return &TheAnd; - } - break; - } - case Instruction::AShr: - // Signed shr. - // See if this is shifting in some sign extension, then masking it out - // with an and. - if (Op->hasOneUse()) { - uint32_t BitWidth = AndRHS->getType()->getBitWidth(); - uint32_t OpRHSVal = OpRHS->getLimitedValue(BitWidth); - APInt ShrMask(APInt::getLowBitsSet(BitWidth, BitWidth - OpRHSVal)); - Constant *C = Builder.getInt(AndRHS->getValue() & ShrMask); - if (C == AndRHS) { // Masking out bits shifted in. - // (Val ashr C1) & C2 -> (Val lshr C1) & C2 - // Make the argument unsigned. - Value *ShVal = Op->getOperand(0); - ShVal = Builder.CreateLShr(ShVal, OpRHS, Op->getName()); - return BinaryOperator::CreateAnd(ShVal, AndRHS, TheAnd.getName()); - } - } - break; } return nullptr; } @@ -376,6 +292,18 @@ static unsigned conjugateICmpMask(unsigned Mask) { return NewMask; } +// Adapts the external decomposeBitTestICmp for local use. +static bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pred, + Value *&X, Value *&Y, Value *&Z) { + APInt Mask; + if (!llvm::decomposeBitTestICmp(LHS, RHS, Pred, X, Mask)) + return false; + + Y = ConstantInt::get(X->getType(), Mask); + Z = ConstantInt::get(X->getType(), 0); + return true; +} + /// Handle (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E). /// Return the set of pattern classes (from MaskedICmpType) that both LHS and /// RHS satisfy. @@ -384,10 +312,9 @@ static unsigned getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, ICmpInst *RHS, ICmpInst::Predicate &PredL, ICmpInst::Predicate &PredR) { - if (LHS->getOperand(0)->getType() != RHS->getOperand(0)->getType()) - return 0; - // vectors are not (yet?) supported - if (LHS->getOperand(0)->getType()->isVectorTy()) + // vectors are not (yet?) supported. Don't support pointers either. + if (!LHS->getOperand(0)->getType()->isIntegerTy() || + !RHS->getOperand(0)->getType()->isIntegerTy()) return 0; // Here comes the tricky part: @@ -400,24 +327,18 @@ static unsigned getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, Value *L2 = LHS->getOperand(1); Value *L11, *L12, *L21, *L22; // Check whether the icmp can be decomposed into a bit test. 
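// Aside (illustrative, not part of the patch): the flavor of rewrite that
// decomposeBitTestICmp performs, shown here only for the sign-bit case. A
// signed compare against 0 or -1 is really a test of the sign bit, i.e. an
// "(X & Mask) ==/!= 0" form, which is exactly what the masked-icmp folding
// below wants to see. The real helper covers more predicates than this.
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t SignMask = 0x80000000u;
  for (int32_t X : {INT32_MIN, -5, 0, 7, INT32_MAX}) {
    uint32_t U = static_cast<uint32_t>(X);
    assert((X < 0) == ((U & SignMask) != 0));   // icmp slt X, 0
    assert((X > -1) == ((U & SignMask) == 0));  // icmp sgt X, -1
  }
  return 0;
}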
- if (decomposeBitTestICmp(LHS, PredL, L11, L12, L2)) { + if (decomposeBitTestICmp(L1, L2, PredL, L11, L12, L2)) { L21 = L22 = L1 = nullptr; } else { // Look for ANDs in the LHS icmp. - if (!L1->getType()->isIntegerTy()) { - // You can icmp pointers, for example. They really aren't masks. - L11 = L12 = nullptr; - } else if (!match(L1, m_And(m_Value(L11), m_Value(L12)))) { + if (!match(L1, m_And(m_Value(L11), m_Value(L12)))) { // Any icmp can be viewed as being trivially masked; if it allows us to // remove one, it's worth it. L11 = L1; L12 = Constant::getAllOnesValue(L1->getType()); } - if (!L2->getType()->isIntegerTy()) { - // You can icmp pointers, for example. They really aren't masks. - L21 = L22 = nullptr; - } else if (!match(L2, m_And(m_Value(L21), m_Value(L22)))) { + if (!match(L2, m_And(m_Value(L21), m_Value(L22)))) { L21 = L2; L22 = Constant::getAllOnesValue(L2->getType()); } @@ -431,7 +352,7 @@ static unsigned getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, Value *R2 = RHS->getOperand(1); Value *R11, *R12; bool Ok = false; - if (decomposeBitTestICmp(RHS, PredR, R11, R12, R2)) { + if (decomposeBitTestICmp(R1, R2, PredR, R11, R12, R2)) { if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) { A = R11; D = R12; @@ -444,7 +365,7 @@ static unsigned getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, E = R2; R1 = nullptr; Ok = true; - } else if (R1->getType()->isIntegerTy()) { + } else { if (!match(R1, m_And(m_Value(R11), m_Value(R12)))) { // As before, model no mask as a trivial mask if it'll let us do an // optimization. @@ -470,7 +391,7 @@ static unsigned getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, return 0; // Look for ANDs on the right side of the RHS icmp. - if (!Ok && R2->getType()->isIntegerTy()) { + if (!Ok) { if (!match(R2, m_And(m_Value(R11), m_Value(R12)))) { R11 = R2; R12 = Constant::getAllOnesValue(R2->getType()); @@ -980,17 +901,15 @@ Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, return nullptr; } -/// Optimize (fcmp)&(fcmp). NOTE: Unlike the rest of instcombine, this returns -/// a Value which should already be inserted into the function. -Value *InstCombiner::foldAndOfFCmps(FCmpInst *LHS, FCmpInst *RHS) { - Value *Op0LHS = LHS->getOperand(0), *Op0RHS = LHS->getOperand(1); - Value *Op1LHS = RHS->getOperand(0), *Op1RHS = RHS->getOperand(1); - FCmpInst::Predicate Op0CC = LHS->getPredicate(), Op1CC = RHS->getPredicate(); +Value *InstCombiner::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, bool IsAnd) { + Value *LHS0 = LHS->getOperand(0), *LHS1 = LHS->getOperand(1); + Value *RHS0 = RHS->getOperand(0), *RHS1 = RHS->getOperand(1); + FCmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); - if (Op0LHS == Op1RHS && Op0RHS == Op1LHS) { + if (LHS0 == RHS1 && RHS0 == LHS1) { // Swap RHS operands to match LHS. - Op1CC = FCmpInst::getSwappedPredicate(Op1CC); - std::swap(Op1LHS, Op1RHS); + PredR = FCmpInst::getSwappedPredicate(PredR); + std::swap(RHS0, RHS1); } // Simplify (fcmp cc0 x, y) & (fcmp cc1 x, y). 
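// Aside (illustrative, not part of the patch): why AND/OR of the fcmp
// "codes" computes the combined predicate. Treat each predicate as a set of
// the four possible outcomes {equal, greater, less, unordered}; the compare
// is true iff the actual outcome is in the set, so '&' of two codes is set
// intersection and '|' is set union. The bit assignment below is only for
// illustration of what getFCmpCode encodes.
#include <cassert>

enum : unsigned { EQ = 1, GT = 2, LT = 4, UN = 8 };
constexpr unsigned OEQ = EQ, OGT = GT, OLT = LT;
constexpr unsigned OGE = GT | EQ, OLE = LT | EQ, ONE = GT | LT;
constexpr unsigned ORD = EQ | GT | LT, UNO = UN;

int main() {
  assert((OLE & OGE) == OEQ);   // (x <= y) && (x >= y)  ==>  x == y
  assert((OLT | OGT) == ONE);   // (x < y) || (x > y)    ==>  ordered, x != y
  assert((ORD & UNO) == 0u);    // ordered && unordered  ==>  never true
  return 0;
}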
@@ -1002,31 +921,30 @@ Value *InstCombiner::foldAndOfFCmps(FCmpInst *LHS, FCmpInst *RHS) { // bool(R & CC0) && bool(R & CC1) // = bool((R & CC0) & (R & CC1)) // = bool(R & (CC0 & CC1)) <= by re-association, commutation, and idempotency - if (Op0LHS == Op1LHS && Op0RHS == Op1RHS) - return getFCmpValue(getFCmpCode(Op0CC) & getFCmpCode(Op1CC), Op0LHS, Op0RHS, - Builder); + // + // Since (R & CC0) and (R & CC1) are either R or 0, we actually have this: + // bool(R & CC0) || bool(R & CC1) + // = bool((R & CC0) | (R & CC1)) + // = bool(R & (CC0 | CC1)) <= by reversed distribution (contribution? ;) + if (LHS0 == RHS0 && LHS1 == RHS1) { + unsigned FCmpCodeL = getFCmpCode(PredL); + unsigned FCmpCodeR = getFCmpCode(PredR); + unsigned NewPred = IsAnd ? FCmpCodeL & FCmpCodeR : FCmpCodeL | FCmpCodeR; + return getFCmpValue(NewPred, LHS0, LHS1, Builder); + } - if (LHS->getPredicate() == FCmpInst::FCMP_ORD && - RHS->getPredicate() == FCmpInst::FCMP_ORD) { - if (LHS->getOperand(0)->getType() != RHS->getOperand(0)->getType()) + if ((PredL == FCmpInst::FCMP_ORD && PredR == FCmpInst::FCMP_ORD && IsAnd) || + (PredL == FCmpInst::FCMP_UNO && PredR == FCmpInst::FCMP_UNO && !IsAnd)) { + if (LHS0->getType() != RHS0->getType()) return nullptr; - // (fcmp ord x, c) & (fcmp ord y, c) -> (fcmp ord x, y) - if (ConstantFP *LHSC = dyn_cast<ConstantFP>(LHS->getOperand(1))) - if (ConstantFP *RHSC = dyn_cast<ConstantFP>(RHS->getOperand(1))) { - // If either of the constants are nans, then the whole thing returns - // false. - if (LHSC->getValueAPF().isNaN() || RHSC->getValueAPF().isNaN()) - return Builder.getFalse(); - return Builder.CreateFCmpORD(LHS->getOperand(0), RHS->getOperand(0)); - } - - // Handle vector zeros. This occurs because the canonical form of - // "fcmp ord x,x" is "fcmp ord x, 0". - if (isa<ConstantAggregateZero>(LHS->getOperand(1)) && - isa<ConstantAggregateZero>(RHS->getOperand(1))) - return Builder.CreateFCmpORD(LHS->getOperand(0), RHS->getOperand(0)); - return nullptr; + // FCmp canonicalization ensures that (fcmp ord/uno X, X) and + // (fcmp ord/uno X, C) will be transformed to (fcmp X, 0.0). + if (match(LHS1, m_Zero()) && LHS1 == RHS1) + // Ignore the constants because they are obviously not NANs: + // (fcmp ord x, 0.0) & (fcmp ord y, 0.0) -> (fcmp ord x, y) + // (fcmp uno x, 0.0) | (fcmp uno y, 0.0) -> (fcmp uno x, y) + return Builder.CreateFCmp(PredL, LHS0, RHS0); } return nullptr; @@ -1069,30 +987,24 @@ bool InstCombiner::shouldOptimizeCast(CastInst *CI) { if (isEliminableCastPair(PrecedingCI, CI)) return false; - // If this is a vector sext from a compare, then we don't want to break the - // idiom where each element of the extended vector is either zero or all ones. - if (CI->getOpcode() == Instruction::SExt && - isa<CmpInst>(CastSrc) && CI->getDestTy()->isVectorTy()) - return false; - return true; } /// Fold {and,or,xor} (cast X), C. static Instruction *foldLogicCastConstant(BinaryOperator &Logic, CastInst *Cast, InstCombiner::BuilderTy &Builder) { - Constant *C; - if (!match(Logic.getOperand(1), m_Constant(C))) + Constant *C = dyn_cast<Constant>(Logic.getOperand(1)); + if (!C) return nullptr; auto LogicOpc = Logic.getOpcode(); Type *DestTy = Logic.getType(); Type *SrcTy = Cast->getSrcTy(); - // Move the logic operation ahead of a zext if the constant is unchanged in - // the smaller source type. Performing the logic in a smaller type may provide - // more information to later folds, and the smaller logic instruction may be - // cheaper (particularly in the case of vectors). 
+ // Move the logic operation ahead of a zext or sext if the constant is + // unchanged in the smaller source type. Performing the logic in a smaller + // type may provide more information to later folds, and the smaller logic + // instruction may be cheaper (particularly in the case of vectors). Value *X; if (match(Cast, m_OneUse(m_ZExt(m_Value(X))))) { Constant *TruncC = ConstantExpr::getTrunc(C, SrcTy); @@ -1104,6 +1016,16 @@ static Instruction *foldLogicCastConstant(BinaryOperator &Logic, CastInst *Cast, } } + if (match(Cast, m_OneUse(m_SExt(m_Value(X))))) { + Constant *TruncC = ConstantExpr::getTrunc(C, SrcTy); + Constant *SextTruncC = ConstantExpr::getSExt(TruncC, DestTy); + if (SextTruncC == C) { + // LogicOpc (sext X), C --> sext (LogicOpc X, C) + Value *NewOp = Builder.CreateBinOp(LogicOpc, X, TruncC); + return new SExtInst(NewOp, DestTy); + } + } + return nullptr; } @@ -1167,38 +1089,9 @@ Instruction *InstCombiner::foldCastedBitwiseLogic(BinaryOperator &I) { // cast is otherwise not optimizable. This happens for vector sexts. FCmpInst *FCmp0 = dyn_cast<FCmpInst>(Cast0Src); FCmpInst *FCmp1 = dyn_cast<FCmpInst>(Cast1Src); - if (FCmp0 && FCmp1) { - Value *Res = LogicOpc == Instruction::And ? foldAndOfFCmps(FCmp0, FCmp1) - : foldOrOfFCmps(FCmp0, FCmp1); - if (Res) - return CastInst::Create(CastOpcode, Res, DestTy); - return nullptr; - } - - return nullptr; -} - -static Instruction *foldBoolSextMaskToSelect(BinaryOperator &I) { - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - - // Canonicalize SExt or Not to the LHS - if (match(Op1, m_SExt(m_Value())) || match(Op1, m_Not(m_Value()))) { - std::swap(Op0, Op1); - } - - // Fold (and (sext bool to A), B) --> (select bool, B, 0) - Value *X = nullptr; - if (match(Op0, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) { - Value *Zero = Constant::getNullValue(Op1->getType()); - return SelectInst::Create(X, Op1, Zero); - } - - // Fold (and ~(sext bool to A), B) --> (select bool, 0, B) - if (match(Op0, m_Not(m_SExt(m_Value(X)))) && - X->getType()->isIntOrIntVectorTy(1)) { - Value *Zero = Constant::getNullValue(Op0->getType()); - return SelectInst::Create(X, Zero, Op1); - } + if (FCmp0 && FCmp1) + if (Value *R = foldLogicOfFCmps(FCmp0, FCmp1, LogicOpc == Instruction::And)) + return CastInst::Create(CastOpcode, R, DestTy); return nullptr; } @@ -1284,14 +1177,61 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { if (Value *V = SimplifyBSwap(I, Builder)) return replaceInstUsesWith(I, V); - if (match(Op1, m_One())) { - // (1 << x) & 1 --> zext(x == 0) - // (1 >> x) & 1 --> zext(x == 0) - Value *X; - if (match(Op0, m_OneUse(m_LogicalShift(m_One(), m_Value(X))))) { + const APInt *C; + if (match(Op1, m_APInt(C))) { + Value *X, *Y; + if (match(Op0, m_OneUse(m_LogicalShift(m_One(), m_Value(X)))) && + C->isOneValue()) { + // (1 << X) & 1 --> zext(X == 0) + // (1 >> X) & 1 --> zext(X == 0) Value *IsZero = Builder.CreateICmpEQ(X, ConstantInt::get(I.getType(), 0)); return new ZExtInst(IsZero, I.getType()); } + + const APInt *XorC; + if (match(Op0, m_OneUse(m_Xor(m_Value(X), m_APInt(XorC))))) { + // (X ^ C1) & C2 --> (X & C2) ^ (C1&C2) + Constant *NewC = ConstantInt::get(I.getType(), *C & *XorC); + Value *And = Builder.CreateAnd(X, Op1); + And->takeName(Op0); + return BinaryOperator::CreateXor(And, NewC); + } + + const APInt *OrC; + if (match(Op0, m_OneUse(m_Or(m_Value(X), m_APInt(OrC))))) { + // (X | C1) & C2 --> (X & C2^(C1&C2)) | (C1&C2) + // NOTE: This reduces the number of bits set in the & mask, which + // can expose 
opportunities for store narrowing for scalars. + // NOTE: SimplifyDemandedBits should have already removed bits from C1 + // that aren't set in C2. Meaning we can replace (C1&C2) with C1 in + // above, but this feels safer. + APInt Together = *C & *OrC; + Value *And = Builder.CreateAnd(X, ConstantInt::get(I.getType(), + Together ^ *C)); + And->takeName(Op0); + return BinaryOperator::CreateOr(And, ConstantInt::get(I.getType(), + Together)); + } + + // If the mask is only needed on one incoming arm, push the 'and' op up. + if (match(Op0, m_OneUse(m_Xor(m_Value(X), m_Value(Y)))) || + match(Op0, m_OneUse(m_Or(m_Value(X), m_Value(Y))))) { + APInt NotAndMask(~(*C)); + BinaryOperator::BinaryOps BinOp = cast<BinaryOperator>(Op0)->getOpcode(); + if (MaskedValueIsZero(X, NotAndMask, 0, &I)) { + // Not masking anything out for the LHS, move mask to RHS. + // and ({x}or X, Y), C --> {x}or X, (and Y, C) + Value *NewRHS = Builder.CreateAnd(Y, Op1, Y->getName() + ".masked"); + return BinaryOperator::Create(BinOp, X, NewRHS); + } + if (!isa<Constant>(Y) && MaskedValueIsZero(Y, NotAndMask, 0, &I)) { + // Not masking anything out for the RHS, move mask to LHS. + // and ({x}or X, Y), C --> {x}or (and X, C), Y + Value *NewLHS = Builder.CreateAnd(X, Op1, X->getName() + ".masked"); + return BinaryOperator::Create(BinOp, NewLHS, Y); + } + } + } if (ConstantInt *AndRHS = dyn_cast<ConstantInt>(Op1)) { @@ -1299,34 +1239,6 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { // Optimize a variety of ((val OP C1) & C2) combinations... if (BinaryOperator *Op0I = dyn_cast<BinaryOperator>(Op0)) { - Value *Op0LHS = Op0I->getOperand(0); - Value *Op0RHS = Op0I->getOperand(1); - switch (Op0I->getOpcode()) { - default: break; - case Instruction::Xor: - case Instruction::Or: { - // If the mask is only needed on one incoming arm, push it up. - if (!Op0I->hasOneUse()) break; - - APInt NotAndRHS(~AndRHSMask); - if (MaskedValueIsZero(Op0LHS, NotAndRHS, 0, &I)) { - // Not masking anything out for the LHS, move to RHS. - Value *NewRHS = Builder.CreateAnd(Op0RHS, AndRHS, - Op0RHS->getName()+".masked"); - return BinaryOperator::Create(Op0I->getOpcode(), Op0LHS, NewRHS); - } - if (!isa<Constant>(Op0RHS) && - MaskedValueIsZero(Op0RHS, NotAndRHS, 0, &I)) { - // Not masking anything out for the RHS, move to LHS. - Value *NewLHS = Builder.CreateAnd(Op0LHS, AndRHS, - Op0LHS->getName()+".masked"); - return BinaryOperator::Create(Op0I->getOpcode(), NewLHS, Op0RHS); - } - - break; - } - } - // ((C1 OP zext(X)) & C2) -> zext((C1-X) & C2) if C2 fits in the bitwidth // of X and OP behaves well when given trunc(C1) and X. switch (Op0I->getOpcode()) { @@ -1343,6 +1255,7 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { if (AndRHSMask.isIntN(X->getType()->getScalarSizeInBits())) { auto *TruncC1 = ConstantExpr::getTrunc(C1, X->getType()); Value *BinOp; + Value *Op0LHS = Op0I->getOperand(0); if (isa<ZExtInst>(Op0LHS)) BinOp = Builder.CreateBinOp(Op0I->getOpcode(), X, TruncC1); else @@ -1467,17 +1380,22 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { } } - // If and'ing two fcmp, try combine them into one. 
if (FCmpInst *LHS = dyn_cast<FCmpInst>(I.getOperand(0))) if (FCmpInst *RHS = dyn_cast<FCmpInst>(I.getOperand(1))) - if (Value *Res = foldAndOfFCmps(LHS, RHS)) + if (Value *Res = foldLogicOfFCmps(LHS, RHS, true)) return replaceInstUsesWith(I, Res); if (Instruction *CastedAnd = foldCastedBitwiseLogic(I)) return CastedAnd; - if (Instruction *Select = foldBoolSextMaskToSelect(I)) - return Select; + // and(sext(A), B) / and(B, sext(A)) --> A ? B : 0, where A is i1 or <N x i1>. + Value *A; + if (match(Op0, m_OneUse(m_SExt(m_Value(A)))) && + A->getType()->isIntOrIntVectorTy(1)) + return SelectInst::Create(A, Op1, Constant::getNullValue(I.getType())); + if (match(Op1, m_OneUse(m_SExt(m_Value(A)))) && + A->getType()->isIntOrIntVectorTy(1)) + return SelectInst::Create(A, Op0, Constant::getNullValue(I.getType())); return Changed ? &I : nullptr; } @@ -1567,8 +1485,9 @@ static Value *getSelectCondition(Value *A, Value *B, // If both operands are constants, see if the constants are inverse bitmasks. Constant *AC, *BC; if (match(A, m_Constant(AC)) && match(B, m_Constant(BC)) && - areInverseVectorBitmasks(AC, BC)) - return ConstantExpr::getTrunc(AC, CmpInst::makeCmpResultType(Ty)); + areInverseVectorBitmasks(AC, BC)) { + return Builder.CreateZExtOrTrunc(AC, CmpInst::makeCmpResultType(Ty)); + } // If both operands are xor'd with constants using the same sexted boolean // operand, see if the constants are inverse bitmasks. @@ -1832,120 +1751,6 @@ Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, return nullptr; } -/// Optimize (fcmp)|(fcmp). NOTE: Unlike the rest of instcombine, this returns -/// a Value which should already be inserted into the function. -Value *InstCombiner::foldOrOfFCmps(FCmpInst *LHS, FCmpInst *RHS) { - Value *Op0LHS = LHS->getOperand(0), *Op0RHS = LHS->getOperand(1); - Value *Op1LHS = RHS->getOperand(0), *Op1RHS = RHS->getOperand(1); - FCmpInst::Predicate Op0CC = LHS->getPredicate(), Op1CC = RHS->getPredicate(); - - if (Op0LHS == Op1RHS && Op0RHS == Op1LHS) { - // Swap RHS operands to match LHS. - Op1CC = FCmpInst::getSwappedPredicate(Op1CC); - std::swap(Op1LHS, Op1RHS); - } - - // Simplify (fcmp cc0 x, y) | (fcmp cc1 x, y). - // This is a similar transformation to the one in FoldAndOfFCmps. - // - // Since (R & CC0) and (R & CC1) are either R or 0, we actually have this: - // bool(R & CC0) || bool(R & CC1) - // = bool((R & CC0) | (R & CC1)) - // = bool(R & (CC0 | CC1)) <= by reversed distribution (contribution? ;) - if (Op0LHS == Op1LHS && Op0RHS == Op1RHS) - return getFCmpValue(getFCmpCode(Op0CC) | getFCmpCode(Op1CC), Op0LHS, Op0RHS, - Builder); - - if (LHS->getPredicate() == FCmpInst::FCMP_UNO && - RHS->getPredicate() == FCmpInst::FCMP_UNO && - LHS->getOperand(0)->getType() == RHS->getOperand(0)->getType()) { - if (ConstantFP *LHSC = dyn_cast<ConstantFP>(LHS->getOperand(1))) - if (ConstantFP *RHSC = dyn_cast<ConstantFP>(RHS->getOperand(1))) { - // If either of the constants are nans, then the whole thing returns - // true. - if (LHSC->getValueAPF().isNaN() || RHSC->getValueAPF().isNaN()) - return Builder.getTrue(); - - // Otherwise, no need to compare the two constants, compare the - // rest. - return Builder.CreateFCmpUNO(LHS->getOperand(0), RHS->getOperand(0)); - } - - // Handle vector zeros. This occurs because the canonical form of - // "fcmp uno x,x" is "fcmp uno x, 0". 
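// Aside (illustrative, not part of the patch): "fcmp uno x, 0.0" is an
// isnan(x) test, because a comparison is unordered exactly when an operand
// is NaN and the constant 0.0 never is. That is why two such zero-compares
// can be merged into a single "fcmp uno x, y" (and dually for "fcmp ord").
#include <cassert>
#include <cmath>

int main() {
  for (double X : {0.0, 1.5, std::nan("")}) {
    bool UnoWithZero = !(X >= 0.0) && !(X < 0.0);  // unordered vs. 0.0
    assert(UnoWithZero == std::isnan(X));
    assert((X != X) == std::isnan(X));             // "fcmp uno x, x"
  }
  return 0;
}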
- if (isa<ConstantAggregateZero>(LHS->getOperand(1)) && - isa<ConstantAggregateZero>(RHS->getOperand(1))) - return Builder.CreateFCmpUNO(LHS->getOperand(0), RHS->getOperand(0)); - - return nullptr; - } - - return nullptr; -} - -/// This helper function folds: -/// -/// ((A | B) & C1) | (B & C2) -/// -/// into: -/// -/// (A & C1) | B -/// -/// when the XOR of the two constants is "all ones" (-1). -static Instruction *FoldOrWithConstants(BinaryOperator &I, Value *Op, - Value *A, Value *B, Value *C, - InstCombiner::BuilderTy &Builder) { - ConstantInt *CI1 = dyn_cast<ConstantInt>(C); - if (!CI1) return nullptr; - - Value *V1 = nullptr; - ConstantInt *CI2 = nullptr; - if (!match(Op, m_And(m_Value(V1), m_ConstantInt(CI2)))) return nullptr; - - APInt Xor = CI1->getValue() ^ CI2->getValue(); - if (!Xor.isAllOnesValue()) return nullptr; - - if (V1 == A || V1 == B) { - Value *NewOp = Builder.CreateAnd((V1 == A) ? B : A, CI1); - return BinaryOperator::CreateOr(NewOp, V1); - } - - return nullptr; -} - -/// \brief This helper function folds: -/// -/// ((A ^ B) & C1) | (B & C2) -/// -/// into: -/// -/// (A & C1) ^ B -/// -/// when the XOR of the two constants is "all ones" (-1). -static Instruction *FoldXorWithConstants(BinaryOperator &I, Value *Op, - Value *A, Value *B, Value *C, - InstCombiner::BuilderTy &Builder) { - ConstantInt *CI1 = dyn_cast<ConstantInt>(C); - if (!CI1) - return nullptr; - - Value *V1 = nullptr; - ConstantInt *CI2 = nullptr; - if (!match(Op, m_And(m_Value(V1), m_ConstantInt(CI2)))) - return nullptr; - - APInt Xor = CI1->getValue() ^ CI2->getValue(); - if (!Xor.isAllOnesValue()) - return nullptr; - - if (V1 == A || V1 == B) { - Value *NewOp = Builder.CreateAnd(V1 == A ? B : A, CI1); - return BinaryOperator::CreateXor(NewOp, V1); - } - - return nullptr; -} - // FIXME: We use commutative matchers (m_c_*) for some, but not all, matches // here. We should standardize that construct where it is needed or choose some // other way to ensure that commutated variants of patterns are not missed. 
@@ -2011,10 +1816,10 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { Value *C = nullptr, *D = nullptr; if (match(Op0, m_And(m_Value(A), m_Value(C))) && match(Op1, m_And(m_Value(B), m_Value(D)))) { - Value *V1 = nullptr, *V2 = nullptr; ConstantInt *C1 = dyn_cast<ConstantInt>(C); ConstantInt *C2 = dyn_cast<ConstantInt>(D); if (C1 && C2) { // (A & C1)|(B & C2) + Value *V1 = nullptr, *V2 = nullptr; if ((C1->getValue() & C2->getValue()).isNullValue()) { // ((V | N) & C1) | (V & C2) --> (V|N) & (C1|C2) // iff (C1&C2) == 0 and (N&~C1) == 0 @@ -2046,6 +1851,24 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { Builder.getInt(C1->getValue()|C2->getValue())); } } + + if (C1->getValue() == ~C2->getValue()) { + Value *X; + + // ((X|B)&C1)|(B&C2) -> (X&C1) | B iff C1 == ~C2 + if (match(A, m_c_Or(m_Value(X), m_Specific(B)))) + return BinaryOperator::CreateOr(Builder.CreateAnd(X, C1), B); + // (A&C2)|((X|A)&C1) -> (X&C2) | A iff C1 == ~C2 + if (match(B, m_c_Or(m_Specific(A), m_Value(X)))) + return BinaryOperator::CreateOr(Builder.CreateAnd(X, C2), A); + + // ((X^B)&C1)|(B&C2) -> (X&C1) ^ B iff C1 == ~C2 + if (match(A, m_c_Xor(m_Value(X), m_Specific(B)))) + return BinaryOperator::CreateXor(Builder.CreateAnd(X, C1), B); + // (A&C2)|((X^A)&C1) -> (X&C2) ^ A iff C1 == ~C2 + if (match(B, m_c_Xor(m_Specific(A), m_Value(X)))) + return BinaryOperator::CreateXor(Builder.CreateAnd(X, C2), A); + } } // Don't try to form a select if it's unlikely that we'll get rid of at @@ -2070,27 +1893,6 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { if (Value *V = matchSelectFromAndOr(D, B, C, A, Builder)) return replaceInstUsesWith(I, V); } - - // ((A|B)&1)|(B&-2) -> (A&1) | B - if (match(A, m_c_Or(m_Value(V1), m_Specific(B)))) { - if (Instruction *Ret = FoldOrWithConstants(I, Op1, V1, B, C, Builder)) - return Ret; - } - // (B&-2)|((A|B)&1) -> (A&1) | B - if (match(B, m_c_Or(m_Specific(A), m_Value(V1)))) { - if (Instruction *Ret = FoldOrWithConstants(I, Op0, A, V1, D, Builder)) - return Ret; - } - // ((A^B)&1)|(B&-2) -> (A&1) ^ B - if (match(A, m_c_Xor(m_Value(V1), m_Specific(B)))) { - if (Instruction *Ret = FoldXorWithConstants(I, Op1, V1, B, C, Builder)) - return Ret; - } - // (B&-2)|((A^B)&1) -> (A&1) ^ B - if (match(B, m_c_Xor(m_Specific(A), m_Value(V1)))) { - if (Instruction *Ret = FoldXorWithConstants(I, Op0, A, V1, D, Builder)) - return Ret; - } } // (A ^ B) | ((B ^ C) ^ A) -> (A ^ B) | C @@ -2182,10 +1984,9 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { } } - // (fcmp uno x, c) | (fcmp uno y, c) -> (fcmp uno x, y) if (FCmpInst *LHS = dyn_cast<FCmpInst>(I.getOperand(0))) if (FCmpInst *RHS = dyn_cast<FCmpInst>(I.getOperand(1))) - if (Value *Res = foldOrOfFCmps(LHS, RHS)) + if (Value *Res = foldLogicOfFCmps(LHS, RHS, false)) return replaceInstUsesWith(I, Res); if (Instruction *CastedOr = foldCastedBitwiseLogic(I)) @@ -2434,60 +2235,51 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { return replaceInstUsesWith(I, Op0); } - if (ConstantInt *RHSC = dyn_cast<ConstantInt>(Op1)) { - // fold (xor(zext(cmp)), 1) and (xor(sext(cmp)), -1) to ext(!cmp). 
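// Aside (illustrative, not part of the patch): the identity behind folding
// xor(zext(cmp), 1) and xor(sext(cmp), -1) into ext(!cmp). zext of a bool
// is 0 or 1, so XOR with 1 flips it; sext of a bool is 0 or -1 (all ones),
// so XOR with -1 flips it. Either way the extension is kept and only the
// compare's predicate needs to be inverted.
#include <cassert>
#include <cstdint>

int main() {
  for (bool C : {false, true}) {
    uint32_t Z = C ? 1u : 0u;             // zext i1 C to i32
    int32_t S = C ? -1 : 0;               // sext i1 C to i32
    assert((Z ^ 1u) == (!C ? 1u : 0u));   // zext i1 (!C)
    assert((S ^ -1) == (!C ? -1 : 0));    // sext i1 (!C)
  }
  return 0;
}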
- if (CastInst *Op0C = dyn_cast<CastInst>(Op0)) { - if (CmpInst *CI = dyn_cast<CmpInst>(Op0C->getOperand(0))) { - if (CI->hasOneUse() && Op0C->hasOneUse()) { - Instruction::CastOps Opcode = Op0C->getOpcode(); - if ((Opcode == Instruction::ZExt || Opcode == Instruction::SExt) && - (RHSC == ConstantExpr::getCast(Opcode, Builder.getTrue(), - Op0C->getDestTy()))) { - CI->setPredicate(CI->getInversePredicate()); - return CastInst::Create(Opcode, CI, Op0C->getType()); - } + { + const APInt *RHSC; + if (match(Op1, m_APInt(RHSC))) { + Value *X; + const APInt *C; + if (match(Op0, m_Sub(m_APInt(C), m_Value(X)))) { + // ~(c-X) == X-c-1 == X+(-c-1) + if (RHSC->isAllOnesValue()) { + Constant *NewC = ConstantInt::get(I.getType(), -(*C) - 1); + return BinaryOperator::CreateAdd(X, NewC); + } + if (RHSC->isSignMask()) { + // (C - X) ^ signmask -> (C + signmask - X) + Constant *NewC = ConstantInt::get(I.getType(), *C + *RHSC); + return BinaryOperator::CreateSub(NewC, X); + } + } else if (match(Op0, m_Add(m_Value(X), m_APInt(C)))) { + // ~(X-c) --> (-c-1)-X + if (RHSC->isAllOnesValue()) { + Constant *NewC = ConstantInt::get(I.getType(), -(*C) - 1); + return BinaryOperator::CreateSub(NewC, X); } + if (RHSC->isSignMask()) { + // (X + C) ^ signmask -> (X + C + signmask) + Constant *NewC = ConstantInt::get(I.getType(), *C + *RHSC); + return BinaryOperator::CreateAdd(X, NewC); + } + } + + // (X|C1)^C2 -> X^(C1^C2) iff X&~C1 == 0 + if (match(Op0, m_Or(m_Value(X), m_APInt(C))) && + MaskedValueIsZero(X, *C, 0, &I)) { + Constant *NewC = ConstantInt::get(I.getType(), *C ^ *RHSC); + Worklist.Add(cast<Instruction>(Op0)); + I.setOperand(0, X); + I.setOperand(1, NewC); + return &I; } } + } + if (ConstantInt *RHSC = dyn_cast<ConstantInt>(Op1)) { if (BinaryOperator *Op0I = dyn_cast<BinaryOperator>(Op0)) { - // ~(c-X) == X-c-1 == X+(-c-1) - if (Op0I->getOpcode() == Instruction::Sub && RHSC->isMinusOne()) - if (Constant *Op0I0C = dyn_cast<Constant>(Op0I->getOperand(0))) { - Constant *NegOp0I0C = ConstantExpr::getNeg(Op0I0C); - return BinaryOperator::CreateAdd(Op0I->getOperand(1), - SubOne(NegOp0I0C)); - } - if (ConstantInt *Op0CI = dyn_cast<ConstantInt>(Op0I->getOperand(1))) { - if (Op0I->getOpcode() == Instruction::Add) { - // ~(X-c) --> (-c-1)-X - if (RHSC->isMinusOne()) { - Constant *NegOp0CI = ConstantExpr::getNeg(Op0CI); - return BinaryOperator::CreateSub(SubOne(NegOp0CI), - Op0I->getOperand(0)); - } else if (RHSC->getValue().isSignMask()) { - // (X + C) ^ signmask -> (X + C + signmask) - Constant *C = Builder.getInt(RHSC->getValue() + Op0CI->getValue()); - return BinaryOperator::CreateAdd(Op0I->getOperand(0), C); - - } - } else if (Op0I->getOpcode() == Instruction::Or) { - // (X|C1)^C2 -> X^(C1|C2) iff X&~C1 == 0 - if (MaskedValueIsZero(Op0I->getOperand(0), Op0CI->getValue(), - 0, &I)) { - Constant *NewRHS = ConstantExpr::getOr(Op0CI, RHSC); - // Anything in both C1 and C2 is known to be zero, remove it from - // NewRHS. 
- Constant *CommonBits = ConstantExpr::getAnd(Op0CI, RHSC); - NewRHS = ConstantExpr::getAnd(NewRHS, - ConstantExpr::getNot(CommonBits)); - Worklist.Add(Op0I); - I.setOperand(0, Op0I->getOperand(0)); - I.setOperand(1, NewRHS); - return &I; - } - } else if (Op0I->getOpcode() == Instruction::LShr) { + if (Op0I->getOpcode() == Instruction::LShr) { // ((X^C1) >> C2) ^ C3 -> (X>>C2) ^ ((C1>>C2)^C3) // E1 = "X ^ C1" BinaryOperator *E1; @@ -2605,5 +2397,25 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { if (Instruction *CastedXor = foldCastedBitwiseLogic(I)) return CastedXor; + // Canonicalize the shifty way to code absolute value to the common pattern. + // There are 4 potential commuted variants. Move the 'ashr' candidate to Op1. + // We're relying on the fact that we only do this transform when the shift has + // exactly 2 uses and the add has exactly 1 use (otherwise, we might increase + // instructions). + if (Op0->getNumUses() == 2) + std::swap(Op0, Op1); + + const APInt *ShAmt; + Type *Ty = I.getType(); + if (match(Op1, m_AShr(m_Value(A), m_APInt(ShAmt))) && + Op1->getNumUses() == 2 && *ShAmt == Ty->getScalarSizeInBits() - 1 && + match(Op0, m_OneUse(m_c_Add(m_Specific(A), m_Specific(Op1))))) { + // B = ashr i32 A, 31 ; smear the sign bit + // xor (add A, B), B ; add -1 and flip bits if negative + // --> (A < 0) ? -A : A + Value *Cmp = Builder.CreateICmpSLT(A, ConstantInt::getNullValue(Ty)); + return SelectInst::Create(Cmp, Builder.CreateNeg(A), A); + } + return Changed ? &I : nullptr; } diff --git a/lib/Transforms/InstCombine/InstCombineCalls.cpp b/lib/Transforms/InstCombine/InstCombineCalls.cpp index 391c430dab75..aa055121e710 100644 --- a/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -16,16 +16,20 @@ #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/Twine.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" @@ -40,18 +44,26 @@ #include "llvm/IR/PatternMatch.h" #include "llvm/IR/Statepoint.h" #include "llvm/IR/Type.h" +#include "llvm/IR/User.h" #include "llvm/IR/Value.h" #include "llvm/IR/ValueHandle.h" +#include "llvm/Support/AtomicOrdering.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/KnownBits.h" #include "llvm/Support/MathExtras.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/InstCombine/InstCombineWorklist.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/SimplifyLibCalls.h" #include <algorithm> #include <cassert> #include <cstdint> #include <cstring> +#include <utility> #include <vector> using namespace llvm; @@ -94,8 +106,8 @@ static Constant *getNegativeIsTrueBoolVec(ConstantDataVector *V) { return ConstantVector::get(BoolVec); } -Instruction *InstCombiner::SimplifyElementUnorderedAtomicMemCpy( - ElementUnorderedAtomicMemCpyInst *AMI) { +Instruction * 
+InstCombiner::SimplifyElementUnorderedAtomicMemCpy(AtomicMemCpyInst *AMI) { // Try to unfold this intrinsic into sequence of explicit atomic loads and // stores. // First check that number of elements is compile time constant. @@ -515,7 +527,7 @@ static Value *simplifyX86varShift(const IntrinsicInst &II, // If all elements out of range or UNDEF, return vector of zeros/undefs. // ArithmeticShift should only hit this if they are all UNDEF. auto OutOfRange = [&](int Idx) { return (Idx < 0) || (BitWidth <= Idx); }; - if (all_of(ShiftAmts, OutOfRange)) { + if (llvm::all_of(ShiftAmts, OutOfRange)) { SmallVector<Constant *, 8> ConstantVec; for (int Idx : ShiftAmts) { if (Idx < 0) { @@ -1094,72 +1106,6 @@ static Value *simplifyX86vpermv(const IntrinsicInst &II, return Builder.CreateShuffleVector(V1, V2, ShuffleMask); } -/// The shuffle mask for a perm2*128 selects any two halves of two 256-bit -/// source vectors, unless a zero bit is set. If a zero bit is set, -/// then ignore that half of the mask and clear that half of the vector. -static Value *simplifyX86vperm2(const IntrinsicInst &II, - InstCombiner::BuilderTy &Builder) { - auto *CInt = dyn_cast<ConstantInt>(II.getArgOperand(2)); - if (!CInt) - return nullptr; - - VectorType *VecTy = cast<VectorType>(II.getType()); - ConstantAggregateZero *ZeroVector = ConstantAggregateZero::get(VecTy); - - // The immediate permute control byte looks like this: - // [1:0] - select 128 bits from sources for low half of destination - // [2] - ignore - // [3] - zero low half of destination - // [5:4] - select 128 bits from sources for high half of destination - // [6] - ignore - // [7] - zero high half of destination - - uint8_t Imm = CInt->getZExtValue(); - - bool LowHalfZero = Imm & 0x08; - bool HighHalfZero = Imm & 0x80; - - // If both zero mask bits are set, this was just a weird way to - // generate a zero vector. - if (LowHalfZero && HighHalfZero) - return ZeroVector; - - // If 0 or 1 zero mask bits are set, this is a simple shuffle. - unsigned NumElts = VecTy->getNumElements(); - unsigned HalfSize = NumElts / 2; - SmallVector<uint32_t, 8> ShuffleMask(NumElts); - - // The high bit of the selection field chooses the 1st or 2nd operand. - bool LowInputSelect = Imm & 0x02; - bool HighInputSelect = Imm & 0x20; - - // The low bit of the selection field chooses the low or high half - // of the selected operand. - bool LowHalfSelect = Imm & 0x01; - bool HighHalfSelect = Imm & 0x10; - - // Determine which operand(s) are actually in use for this instruction. - Value *V0 = LowInputSelect ? II.getArgOperand(1) : II.getArgOperand(0); - Value *V1 = HighInputSelect ? II.getArgOperand(1) : II.getArgOperand(0); - - // If needed, replace operands based on zero mask. - V0 = LowHalfZero ? ZeroVector : V0; - V1 = HighHalfZero ? ZeroVector : V1; - - // Permute low half of result. - unsigned StartIndex = LowHalfSelect ? HalfSize : 0; - for (unsigned i = 0; i < HalfSize; ++i) - ShuffleMask[i] = StartIndex + i; - - // Permute high half of result. - StartIndex = HighHalfSelect ? HalfSize : 0; - StartIndex += NumElts; - for (unsigned i = 0; i < HalfSize; ++i) - ShuffleMask[i + HalfSize] = StartIndex + i; - - return Builder.CreateShuffleVector(V0, V1, ShuffleMask); -} - /// Decode XOP integer vector comparison intrinsics. static Value *simplifyX86vpcom(const IntrinsicInst &II, InstCombiner::BuilderTy &Builder, @@ -1650,7 +1596,6 @@ static Instruction *SimplifyNVVMIntrinsic(IntrinsicInst *II, InstCombiner &IC) { // IntrinsicInstr with target-generic LLVM IR. 
const SimplifyAction Action = [II]() -> SimplifyAction { switch (II->getIntrinsicID()) { - // NVVM intrinsics that map directly to LLVM intrinsics. case Intrinsic::nvvm_ceil_d: return {Intrinsic::ceil, FTZ_Any}; @@ -1932,7 +1877,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { if (Changed) return II; } - if (auto *AMI = dyn_cast<ElementUnorderedAtomicMemCpyInst>(II)) { + if (auto *AMI = dyn_cast<AtomicMemCpyInst>(II)) { if (Constant *C = dyn_cast<Constant>(AMI->getLength())) if (C->isNullValue()) return eraseInstFromFunction(*AMI); @@ -2072,7 +2017,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { } case Intrinsic::fmuladd: { // Canonicalize fast fmuladd to the separate fmul + fadd. - if (II->hasUnsafeAlgebra()) { + if (II->isFast()) { BuilderTy::FastMathFlagGuard Guard(Builder); Builder.setFastMathFlags(II->getFastMathFlags()); Value *Mul = Builder.CreateFMul(II->getArgOperand(0), @@ -2248,6 +2193,52 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { } break; + case Intrinsic::x86_bmi_bextr_32: + case Intrinsic::x86_bmi_bextr_64: + case Intrinsic::x86_tbm_bextri_u32: + case Intrinsic::x86_tbm_bextri_u64: + // If the RHS is a constant we can try some simplifications. + if (auto *C = dyn_cast<ConstantInt>(II->getArgOperand(1))) { + uint64_t Shift = C->getZExtValue(); + uint64_t Length = (Shift >> 8) & 0xff; + Shift &= 0xff; + unsigned BitWidth = II->getType()->getIntegerBitWidth(); + // If the length is 0 or the shift is out of range, replace with zero. + if (Length == 0 || Shift >= BitWidth) + return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), 0)); + // If the LHS is also a constant, we can completely constant fold this. + if (auto *InC = dyn_cast<ConstantInt>(II->getArgOperand(0))) { + uint64_t Result = InC->getZExtValue() >> Shift; + if (Length > BitWidth) + Length = BitWidth; + Result &= maskTrailingOnes<uint64_t>(Length); + return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), Result)); + } + // TODO should we turn this into 'and' if shift is 0? Or 'shl' if we + // are only masking bits that a shift already cleared? + } + break; + + case Intrinsic::x86_bmi_bzhi_32: + case Intrinsic::x86_bmi_bzhi_64: + // If the RHS is a constant we can try some simplifications. + if (auto *C = dyn_cast<ConstantInt>(II->getArgOperand(1))) { + uint64_t Index = C->getZExtValue() & 0xff; + unsigned BitWidth = II->getType()->getIntegerBitWidth(); + if (Index >= BitWidth) + return replaceInstUsesWith(CI, II->getArgOperand(0)); + if (Index == 0) + return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), 0)); + // If the LHS is also a constant, we can completely constant fold this. + if (auto *InC = dyn_cast<ConstantInt>(II->getArgOperand(0))) { + uint64_t Result = InC->getZExtValue(); + Result &= maskTrailingOnes<uint64_t>(Index); + return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), Result)); + } + // TODO should we convert this to an AND if the RHS is constant? 
+ } + break; + case Intrinsic::x86_vcvtph2ps_128: case Intrinsic::x86_vcvtph2ps_256: { auto Arg = II->getArgOperand(0); @@ -2333,11 +2324,10 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_sse2_pmovmskb_128: case Intrinsic::x86_avx_movmsk_pd_256: case Intrinsic::x86_avx_movmsk_ps_256: - case Intrinsic::x86_avx2_pmovmskb: { + case Intrinsic::x86_avx2_pmovmskb: if (Value *V = simplifyX86movmsk(*II)) return replaceInstUsesWith(*II, V); break; - } case Intrinsic::x86_sse_comieq_ss: case Intrinsic::x86_sse_comige_ss: @@ -2972,14 +2962,6 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { } break; - case Intrinsic::x86_avx_vperm2f128_pd_256: - case Intrinsic::x86_avx_vperm2f128_ps_256: - case Intrinsic::x86_avx_vperm2f128_si_256: - case Intrinsic::x86_avx2_vperm2i128: - if (Value *V = simplifyX86vperm2(*II, Builder)) - return replaceInstUsesWith(*II, V); - break; - case Intrinsic::x86_avx_maskload_ps: case Intrinsic::x86_avx_maskload_pd: case Intrinsic::x86_avx_maskload_ps_256: @@ -3399,7 +3381,6 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return II; break; - } case Intrinsic::amdgcn_fmed3: { // Note this does not preserve proper sNaN behavior if IEEE-mode is enabled @@ -3560,6 +3541,21 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } + case Intrinsic::amdgcn_wqm_vote: { + // wqm_vote is identity when the argument is constant. + if (!isa<Constant>(II->getArgOperand(0))) + break; + + return replaceInstUsesWith(*II, II->getArgOperand(0)); + } + case Intrinsic::amdgcn_kill: { + const ConstantInt *C = dyn_cast<ConstantInt>(II->getArgOperand(0)); + if (!C || !C->getZExtValue()) + break; + + // amdgcn.kill(i1 1) is a no-op + return eraseInstFromFunction(CI); + } case Intrinsic::stackrestore: { // If the save is right next to the restore, remove the restore. This can // happen when variable allocas are DCE'd. @@ -3611,7 +3607,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::lifetime_start: // Asan needs to poison memory to detect invalid access which is possible // even for empty lifetime range. - if (II->getFunction()->hasFnAttribute(Attribute::SanitizeAddress)) + if (II->getFunction()->hasFnAttribute(Attribute::SanitizeAddress) || + II->getFunction()->hasFnAttribute(Attribute::SanitizeHWAddress)) break; if (removeTriviallyEmptyRange(*II, Intrinsic::lifetime_start, @@ -3697,7 +3694,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return replaceInstUsesWith(*II, ConstantPointerNull::get(PT)); // isKnownNonNull -> nonnull attribute - if (isKnownNonNullAt(DerivedPtr, II, &DT)) + if (isKnownNonZero(DerivedPtr, DL, 0, &AC, II, &DT)) II->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull); } @@ -3740,7 +3737,6 @@ Instruction *InstCombiner::visitFenceInst(FenceInst &FI) { } // InvokeInst simplification -// Instruction *InstCombiner::visitInvokeInst(InvokeInst &II) { return visitCallSite(&II); } @@ -3784,7 +3780,7 @@ Instruction *InstCombiner::tryOptimizeCall(CallInst *CI) { auto InstCombineRAUW = [this](Instruction *From, Value *With) { replaceInstUsesWith(*From, With); }; - LibCallSimplifier Simplifier(DL, &TLI, InstCombineRAUW); + LibCallSimplifier Simplifier(DL, &TLI, ORE, InstCombineRAUW); if (Value *With = Simplifier.optimizeCall(CI)) { ++NumSimplified; return CI->use_empty() ? 
CI : replaceInstUsesWith(*CI, With); @@ -3853,7 +3849,6 @@ static IntrinsicInst *findInitTrampolineFromBB(IntrinsicInst *AdjustTramp, // Given a call to llvm.adjust.trampoline, find and return the corresponding // call to llvm.init.trampoline if the call to the trampoline can be optimized // to a direct call to a function. Otherwise return NULL. -// static IntrinsicInst *findInitTrampoline(Value *Callee) { Callee = Callee->stripPointerCasts(); IntrinsicInst *AdjustTramp = dyn_cast<IntrinsicInst>(Callee); @@ -3886,7 +3881,7 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) { for (Value *V : CS.args()) { if (V->getType()->isPointerTy() && !CS.paramHasAttr(ArgNo, Attribute::NonNull) && - isKnownNonNullAt(V, CS.getInstruction(), &DT)) + isKnownNonZero(V, DL, 0, &AC, CS.getInstruction(), &DT)) ArgNos.push_back(ArgNo); ArgNo++; } @@ -4021,7 +4016,6 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { // Okay, this is a cast from a function to a different type. Unless doing so // would cause a type conversion of one of our arguments, change this call to // be a direct call with arguments casted to the appropriate types. - // FunctionType *FT = Callee->getFunctionType(); Type *OldRetTy = Caller->getType(); Type *NewRetTy = FT->getReturnType(); diff --git a/lib/Transforms/InstCombine/InstCombineCasts.cpp b/lib/Transforms/InstCombine/InstCombineCasts.cpp index dfdfd3e9da84..178c8eaf2502 100644 --- a/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -235,8 +235,8 @@ Instruction::CastOps InstCombiner::isEliminableCastPair(const CastInst *CI1, Type *MidTy = CI1->getDestTy(); Type *DstTy = CI2->getDestTy(); - Instruction::CastOps firstOp = Instruction::CastOps(CI1->getOpcode()); - Instruction::CastOps secondOp = Instruction::CastOps(CI2->getOpcode()); + Instruction::CastOps firstOp = CI1->getOpcode(); + Instruction::CastOps secondOp = CI2->getOpcode(); Type *SrcIntPtrTy = SrcTy->isPtrOrPtrVectorTy() ? DL.getIntPtrType(SrcTy) : nullptr; Type *MidIntPtrTy = @@ -346,29 +346,50 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombiner &IC, } break; } - case Instruction::Shl: + case Instruction::Shl: { // If we are truncating the result of this SHL, and if it's a shift of a // constant amount, we can always perform a SHL in a smaller type. - if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) { + const APInt *Amt; + if (match(I->getOperand(1), m_APInt(Amt))) { uint32_t BitWidth = Ty->getScalarSizeInBits(); - if (CI->getLimitedValue(BitWidth) < BitWidth) + if (Amt->getLimitedValue(BitWidth) < BitWidth) return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI); } break; - case Instruction::LShr: + } + case Instruction::LShr: { // If this is a truncate of a logical shr, we can truncate it to a smaller // lshr iff we know that the bits we would otherwise be shifting in are // already zeros. 
- if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) { + const APInt *Amt; + if (match(I->getOperand(1), m_APInt(Amt))) { uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits(); uint32_t BitWidth = Ty->getScalarSizeInBits(); if (IC.MaskedValueIsZero(I->getOperand(0), APInt::getHighBitsSet(OrigBitWidth, OrigBitWidth-BitWidth), 0, CxtI) && - CI->getLimitedValue(BitWidth) < BitWidth) { + Amt->getLimitedValue(BitWidth) < BitWidth) { return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI); } } break; + } + case Instruction::AShr: { + // If this is a truncate of an arithmetic shr, we can truncate it to a + // smaller ashr iff we know that all the bits from the sign bit of the + // original type and the sign bit of the truncate type are similar. + // TODO: It is enough to check that the bits we would be shifting in are + // similar to sign bit of the truncate type. + const APInt *Amt; + if (match(I->getOperand(1), m_APInt(Amt))) { + uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits(); + uint32_t BitWidth = Ty->getScalarSizeInBits(); + if (Amt->getLimitedValue(BitWidth) < BitWidth && + OrigBitWidth - BitWidth < + IC.ComputeNumSignBits(I->getOperand(0), 0, CxtI)) + return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI); + } + break; + } case Instruction::Trunc: // trunc(trunc(x)) -> trunc(x) return true; @@ -443,24 +464,130 @@ static Instruction *foldVecTruncToExtElt(TruncInst &Trunc, InstCombiner &IC) { return ExtractElementInst::Create(VecInput, IC.Builder.getInt32(Elt)); } -/// Try to narrow the width of bitwise logic instructions with constants. -Instruction *InstCombiner::shrinkBitwiseLogic(TruncInst &Trunc) { +/// Rotate left/right may occur in a wider type than necessary because of type +/// promotion rules. Try to narrow all of the component instructions. +Instruction *InstCombiner::narrowRotate(TruncInst &Trunc) { + assert((isa<VectorType>(Trunc.getSrcTy()) || + shouldChangeType(Trunc.getSrcTy(), Trunc.getType())) && + "Don't narrow to an illegal scalar type"); + + // First, find an or'd pair of opposite shifts with the same shifted operand: + // trunc (or (lshr ShVal, ShAmt0), (shl ShVal, ShAmt1)) + Value *Or0, *Or1; + if (!match(Trunc.getOperand(0), m_OneUse(m_Or(m_Value(Or0), m_Value(Or1))))) + return nullptr; + + Value *ShVal, *ShAmt0, *ShAmt1; + if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(ShVal), m_Value(ShAmt0)))) || + !match(Or1, m_OneUse(m_LogicalShift(m_Specific(ShVal), m_Value(ShAmt1))))) + return nullptr; + + auto ShiftOpcode0 = cast<BinaryOperator>(Or0)->getOpcode(); + auto ShiftOpcode1 = cast<BinaryOperator>(Or1)->getOpcode(); + if (ShiftOpcode0 == ShiftOpcode1) + return nullptr; + + // The shift amounts must add up to the narrow bit width. + Value *ShAmt; + bool SubIsOnLHS; + Type *DestTy = Trunc.getType(); + unsigned NarrowWidth = DestTy->getScalarSizeInBits(); + if (match(ShAmt0, + m_OneUse(m_Sub(m_SpecificInt(NarrowWidth), m_Specific(ShAmt1))))) { + ShAmt = ShAmt1; + SubIsOnLHS = true; + } else if (match(ShAmt1, m_OneUse(m_Sub(m_SpecificInt(NarrowWidth), + m_Specific(ShAmt0))))) { + ShAmt = ShAmt0; + SubIsOnLHS = false; + } else { + return nullptr; + } + + // The shifted value must have high zeros in the wide type. Typically, this + // will be a zext, but it could also be the result of an 'and' or 'shift'. 
+ unsigned WideWidth = Trunc.getSrcTy()->getScalarSizeInBits(); + APInt HiBitMask = APInt::getHighBitsSet(WideWidth, WideWidth - NarrowWidth); + if (!MaskedValueIsZero(ShVal, HiBitMask, 0, &Trunc)) + return nullptr; + + // We have an unnecessarily wide rotate! + // trunc (or (lshr ShVal, ShAmt), (shl ShVal, BitWidth - ShAmt)) + // Narrow it down to eliminate the zext/trunc: + // or (lshr trunc(ShVal), ShAmt0'), (shl trunc(ShVal), ShAmt1') + Value *NarrowShAmt = Builder.CreateTrunc(ShAmt, DestTy); + Value *NegShAmt = Builder.CreateNeg(NarrowShAmt); + + // Mask both shift amounts to ensure there's no UB from oversized shifts. + Constant *MaskC = ConstantInt::get(DestTy, NarrowWidth - 1); + Value *MaskedShAmt = Builder.CreateAnd(NarrowShAmt, MaskC); + Value *MaskedNegShAmt = Builder.CreateAnd(NegShAmt, MaskC); + + // Truncate the original value and use narrow ops. + Value *X = Builder.CreateTrunc(ShVal, DestTy); + Value *NarrowShAmt0 = SubIsOnLHS ? MaskedNegShAmt : MaskedShAmt; + Value *NarrowShAmt1 = SubIsOnLHS ? MaskedShAmt : MaskedNegShAmt; + Value *NarrowSh0 = Builder.CreateBinOp(ShiftOpcode0, X, NarrowShAmt0); + Value *NarrowSh1 = Builder.CreateBinOp(ShiftOpcode1, X, NarrowShAmt1); + return BinaryOperator::CreateOr(NarrowSh0, NarrowSh1); +} + +/// Try to narrow the width of math or bitwise logic instructions by pulling a +/// truncate ahead of binary operators. +/// TODO: Transforms for truncated shifts should be moved into here. +Instruction *InstCombiner::narrowBinOp(TruncInst &Trunc) { Type *SrcTy = Trunc.getSrcTy(); Type *DestTy = Trunc.getType(); - if (isa<IntegerType>(SrcTy) && !shouldChangeType(SrcTy, DestTy)) + if (!isa<VectorType>(SrcTy) && !shouldChangeType(SrcTy, DestTy)) return nullptr; - BinaryOperator *LogicOp; - Constant *C; - if (!match(Trunc.getOperand(0), m_OneUse(m_BinOp(LogicOp))) || - !LogicOp->isBitwiseLogicOp() || - !match(LogicOp->getOperand(1), m_Constant(C))) + BinaryOperator *BinOp; + if (!match(Trunc.getOperand(0), m_OneUse(m_BinOp(BinOp)))) return nullptr; - // trunc (logic X, C) --> logic (trunc X, C') - Constant *NarrowC = ConstantExpr::getTrunc(C, DestTy); - Value *NarrowOp0 = Builder.CreateTrunc(LogicOp->getOperand(0), DestTy); - return BinaryOperator::Create(LogicOp->getOpcode(), NarrowOp0, NarrowC); + Value *BinOp0 = BinOp->getOperand(0); + Value *BinOp1 = BinOp->getOperand(1); + switch (BinOp->getOpcode()) { + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + case Instruction::Add: + case Instruction::Sub: + case Instruction::Mul: { + Constant *C; + if (match(BinOp0, m_Constant(C))) { + // trunc (binop C, X) --> binop (trunc C', X) + Constant *NarrowC = ConstantExpr::getTrunc(C, DestTy); + Value *TruncX = Builder.CreateTrunc(BinOp1, DestTy); + return BinaryOperator::Create(BinOp->getOpcode(), NarrowC, TruncX); + } + if (match(BinOp1, m_Constant(C))) { + // trunc (binop X, C) --> binop (trunc X, C') + Constant *NarrowC = ConstantExpr::getTrunc(C, DestTy); + Value *TruncX = Builder.CreateTrunc(BinOp0, DestTy); + return BinaryOperator::Create(BinOp->getOpcode(), TruncX, NarrowC); + } + Value *X; + if (match(BinOp0, m_ZExtOrSExt(m_Value(X))) && X->getType() == DestTy) { + // trunc (binop (ext X), Y) --> binop X, (trunc Y) + Value *NarrowOp1 = Builder.CreateTrunc(BinOp1, DestTy); + return BinaryOperator::Create(BinOp->getOpcode(), X, NarrowOp1); + } + if (match(BinOp1, m_ZExtOrSExt(m_Value(X))) && X->getType() == DestTy) { + // trunc (binop Y, (ext X)) --> binop (trunc Y), X + Value *NarrowOp0 = Builder.CreateTrunc(BinOp0, DestTy); + 
return BinaryOperator::Create(BinOp->getOpcode(), NarrowOp0, X); + } + break; + } + + default: break; + } + + if (Instruction *NarrowOr = narrowRotate(Trunc)) + return NarrowOr; + + return nullptr; } /// Try to narrow the width of a splat shuffle. This could be generalized to any @@ -616,7 +743,7 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { } } - if (Instruction *I = shrinkBitwiseLogic(CI)) + if (Instruction *I = narrowBinOp(CI)) return I; if (Instruction *I = shrinkSplatShuffle(CI, Builder)) @@ -655,13 +782,13 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, ZExtInst &CI, // If we are just checking for a icmp eq of a single bit and zext'ing it // to an integer, then shift the bit to the appropriate place and then // cast to integer to avoid the comparison. - if (ConstantInt *Op1C = dyn_cast<ConstantInt>(ICI->getOperand(1))) { - const APInt &Op1CV = Op1C->getValue(); + const APInt *Op1CV; + if (match(ICI->getOperand(1), m_APInt(Op1CV))) { // zext (x <s 0) to i32 --> x>>u31 true if signbit set. // zext (x >s -1) to i32 --> (x>>u31)^1 true if signbit clear. - if ((ICI->getPredicate() == ICmpInst::ICMP_SLT && Op1CV.isNullValue()) || - (ICI->getPredicate() == ICmpInst::ICMP_SGT && Op1CV.isAllOnesValue())) { + if ((ICI->getPredicate() == ICmpInst::ICMP_SLT && Op1CV->isNullValue()) || + (ICI->getPredicate() == ICmpInst::ICMP_SGT && Op1CV->isAllOnesValue())) { if (!DoTransform) return ICI; Value *In = ICI->getOperand(0); @@ -687,7 +814,7 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, ZExtInst &CI, // zext (X != 0) to i32 --> X>>1 iff X has only the 2nd bit set. // zext (X != 1) to i32 --> X^1 iff X has only the low bit set. // zext (X != 2) to i32 --> (X>>1)^1 iff X has only the 2nd bit set. - if ((Op1CV.isNullValue() || Op1CV.isPowerOf2()) && + if ((Op1CV->isNullValue() || Op1CV->isPowerOf2()) && // This only works for EQ and NE ICI->isEquality()) { // If Op1C some other power of two, convert: @@ -698,12 +825,10 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, ZExtInst &CI, if (!DoTransform) return ICI; bool isNE = ICI->getPredicate() == ICmpInst::ICMP_NE; - if (!Op1CV.isNullValue() && (Op1CV != KnownZeroMask)) { + if (!Op1CV->isNullValue() && (*Op1CV != KnownZeroMask)) { // (X&4) == 2 --> false // (X&4) != 2 --> true - Constant *Res = ConstantInt::get(Type::getInt1Ty(CI.getContext()), - isNE); - Res = ConstantExpr::getZExt(Res, CI.getType()); + Constant *Res = ConstantInt::get(CI.getType(), isNE); return replaceInstUsesWith(CI, Res); } @@ -716,7 +841,7 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, ZExtInst &CI, In->getName() + ".lobit"); } - if (!Op1CV.isNullValue() == isNE) { // Toggle the low bit. + if (!Op1CV->isNullValue() == isNE) { // Toggle the low bit. Constant *One = ConstantInt::get(In->getType(), 1); In = Builder.CreateXor(In, One); } @@ -833,17 +958,23 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear, unsigned VSize = V->getType()->getScalarSizeInBits(); if (IC.MaskedValueIsZero(I->getOperand(1), APInt::getHighBitsSet(VSize, BitsToClear), - 0, CxtI)) + 0, CxtI)) { + // If this is an And instruction and all of the BitsToClear are + // known to be zero we can reset BitsToClear. + if (Opc == Instruction::And) + BitsToClear = 0; return true; + } } // Otherwise, we don't know how to analyze this BitsToClear case yet. return false; - case Instruction::Shl: + case Instruction::Shl: { // We can promote shl(x, cst) if we can promote x. 
Since shl overwrites the // upper bits we can reduce BitsToClear by the shift amount. - if (ConstantInt *Amt = dyn_cast<ConstantInt>(I->getOperand(1))) { + const APInt *Amt; + if (match(I->getOperand(1), m_APInt(Amt))) { if (!canEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI)) return false; uint64_t ShiftAmt = Amt->getZExtValue(); @@ -851,10 +982,12 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear, return true; } return false; - case Instruction::LShr: + } + case Instruction::LShr: { // We can promote lshr(x, cst) if we can promote x. This requires the // ultimate 'and' to clear out the high zero bits we're clearing out though. - if (ConstantInt *Amt = dyn_cast<ConstantInt>(I->getOperand(1))) { + const APInt *Amt; + if (match(I->getOperand(1), m_APInt(Amt))) { if (!canEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI)) return false; BitsToClear += Amt->getZExtValue(); @@ -864,6 +997,7 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear, } // Cannot promote variable LSHR. return false; + } case Instruction::Select: if (!canEvaluateZExtd(I->getOperand(1), Ty, Tmp, IC, CxtI) || !canEvaluateZExtd(I->getOperand(2), Ty, BitsToClear, IC, CxtI) || diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp index a8faaecb5c34..3bc7fae77cb1 100644 --- a/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -17,9 +17,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" -#include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/GetElementPtrTypeIterator.h" @@ -37,77 +35,30 @@ using namespace PatternMatch; STATISTIC(NumSel, "Number of select opts"); -static ConstantInt *extractElement(Constant *V, Constant *Idx) { - return cast<ConstantInt>(ConstantExpr::getExtractElement(V, Idx)); -} - -static bool hasAddOverflow(ConstantInt *Result, - ConstantInt *In1, ConstantInt *In2, - bool IsSigned) { - if (!IsSigned) - return Result->getValue().ult(In1->getValue()); - - if (In2->isNegative()) - return Result->getValue().sgt(In1->getValue()); - return Result->getValue().slt(In1->getValue()); -} - /// Compute Result = In1+In2, returning true if the result overflowed for this /// type. 
-static bool addWithOverflow(Constant *&Result, Constant *In1, - Constant *In2, bool IsSigned = false) { - Result = ConstantExpr::getAdd(In1, In2); - - if (VectorType *VTy = dyn_cast<VectorType>(In1->getType())) { - for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) { - Constant *Idx = ConstantInt::get(Type::getInt32Ty(In1->getContext()), i); - if (hasAddOverflow(extractElement(Result, Idx), - extractElement(In1, Idx), - extractElement(In2, Idx), - IsSigned)) - return true; - } - return false; - } - - return hasAddOverflow(cast<ConstantInt>(Result), - cast<ConstantInt>(In1), cast<ConstantInt>(In2), - IsSigned); -} - -static bool hasSubOverflow(ConstantInt *Result, - ConstantInt *In1, ConstantInt *In2, - bool IsSigned) { - if (!IsSigned) - return Result->getValue().ugt(In1->getValue()); - - if (In2->isNegative()) - return Result->getValue().slt(In1->getValue()); +static bool addWithOverflow(APInt &Result, const APInt &In1, + const APInt &In2, bool IsSigned = false) { + bool Overflow; + if (IsSigned) + Result = In1.sadd_ov(In2, Overflow); + else + Result = In1.uadd_ov(In2, Overflow); - return Result->getValue().sgt(In1->getValue()); + return Overflow; } /// Compute Result = In1-In2, returning true if the result overflowed for this /// type. -static bool subWithOverflow(Constant *&Result, Constant *In1, - Constant *In2, bool IsSigned = false) { - Result = ConstantExpr::getSub(In1, In2); - - if (VectorType *VTy = dyn_cast<VectorType>(In1->getType())) { - for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) { - Constant *Idx = ConstantInt::get(Type::getInt32Ty(In1->getContext()), i); - if (hasSubOverflow(extractElement(Result, Idx), - extractElement(In1, Idx), - extractElement(In2, Idx), - IsSigned)) - return true; - } - return false; - } +static bool subWithOverflow(APInt &Result, const APInt &In1, + const APInt &In2, bool IsSigned = false) { + bool Overflow; + if (IsSigned) + Result = In1.ssub_ov(In2, Overflow); + else + Result = In1.usub_ov(In2, Overflow); - return hasSubOverflow(cast<ConstantInt>(Result), - cast<ConstantInt>(In1), cast<ConstantInt>(In2), - IsSigned); + return Overflow; } /// Given an icmp instruction, return true if any use of this comparison is a @@ -473,8 +424,7 @@ Instruction *InstCombiner::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, // Look for an appropriate type: // - The type of Idx if the magic fits - // - The smallest fitting legal type if we have a DataLayout - // - Default to i32 + // - The smallest fitting legal type if (ArrayElementCount <= Idx->getType()->getIntegerBitWidth()) Ty = Idx->getType(); else @@ -1108,7 +1058,6 @@ Instruction *InstCombiner::foldAllocaCmp(ICmpInst &ICI, // because we don't allow ptrtoint. Memcpy and memmove are safe because // we don't allow stores, so src cannot point to V. case Intrinsic::lifetime_start: case Intrinsic::lifetime_end: - case Intrinsic::dbg_declare: case Intrinsic::dbg_value: case Intrinsic::memcpy: case Intrinsic::memmove: case Intrinsic::memset: continue; default: @@ -1131,8 +1080,7 @@ Instruction *InstCombiner::foldAllocaCmp(ICmpInst &ICI, } /// Fold "icmp pred (X+CI), X". -Instruction *InstCombiner::foldICmpAddOpConst(Instruction &ICI, - Value *X, ConstantInt *CI, +Instruction *InstCombiner::foldICmpAddOpConst(Value *X, ConstantInt *CI, ICmpInst::Predicate Pred) { // From this point on, we know that (X+C <= X) --> (X+C < X) because C != 0, // so the values can never be equal. 
Similarly for all other "or equals" @@ -1367,6 +1315,24 @@ static Instruction *processUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, return ExtractValueInst::Create(Call, 1, "sadd.overflow"); } +// Handle (icmp sgt smin(PosA, B) 0) -> (icmp sgt B 0) +Instruction *InstCombiner::foldICmpWithZero(ICmpInst &Cmp) { + CmpInst::Predicate Pred = Cmp.getPredicate(); + Value *X = Cmp.getOperand(0); + + if (match(Cmp.getOperand(1), m_Zero()) && Pred == ICmpInst::ICMP_SGT) { + Value *A, *B; + SelectPatternResult SPR = matchSelectPattern(X, A, B); + if (SPR.Flavor == SPF_SMIN) { + if (isKnownPositive(A, DL, 0, &AC, &Cmp, &DT)) + return new ICmpInst(Pred, B, Cmp.getOperand(1)); + if (isKnownPositive(B, DL, 0, &AC, &Cmp, &DT)) + return new ICmpInst(Pred, A, Cmp.getOperand(1)); + } + } + return nullptr; +} + // Fold icmp Pred X, C. Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &Cmp) { CmpInst::Predicate Pred = Cmp.getPredicate(); @@ -1398,17 +1364,6 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &Cmp) { return Res; } - // (icmp sgt smin(PosA, B) 0) -> (icmp sgt B 0) - if (C->isNullValue() && Pred == ICmpInst::ICMP_SGT) { - SelectPatternResult SPR = matchSelectPattern(X, A, B); - if (SPR.Flavor == SPF_SMIN) { - if (isKnownPositive(A, DL, 0, &AC, &Cmp, &DT)) - return new ICmpInst(Pred, B, Cmp.getOperand(1)); - if (isKnownPositive(B, DL, 0, &AC, &Cmp, &DT)) - return new ICmpInst(Pred, A, Cmp.getOperand(1)); - } - } - // FIXME: Use m_APInt to allow folds for splat constants. ConstantInt *CI = dyn_cast<ConstantInt>(Cmp.getOperand(1)); if (!CI) @@ -1462,11 +1417,11 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &Cmp) { /// Fold icmp (trunc X, Y), C. Instruction *InstCombiner::foldICmpTruncConstant(ICmpInst &Cmp, - Instruction *Trunc, - const APInt *C) { + TruncInst *Trunc, + const APInt &C) { ICmpInst::Predicate Pred = Cmp.getPredicate(); Value *X = Trunc->getOperand(0); - if (C->isOneValue() && C->getBitWidth() > 1) { + if (C.isOneValue() && C.getBitWidth() > 1) { // icmp slt trunc(signum(V)) 1 --> icmp slt V, 1 Value *V = nullptr; if (Pred == ICmpInst::ICMP_SLT && match(X, m_Signum(m_Value(V)))) @@ -1484,7 +1439,7 @@ Instruction *InstCombiner::foldICmpTruncConstant(ICmpInst &Cmp, // If all the high bits are known, we can do this xform. if ((Known.Zero | Known.One).countLeadingOnes() >= SrcBits - DstBits) { // Pull in the high bits from known-ones set. - APInt NewRHS = C->zext(SrcBits); + APInt NewRHS = C.zext(SrcBits); NewRHS |= Known.One & APInt::getHighBitsSet(SrcBits, SrcBits - DstBits); return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), NewRHS)); } @@ -1496,7 +1451,7 @@ Instruction *InstCombiner::foldICmpTruncConstant(ICmpInst &Cmp, /// Fold icmp (xor X, Y), C. Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp, BinaryOperator *Xor, - const APInt *C) { + const APInt &C) { Value *X = Xor->getOperand(0); Value *Y = Xor->getOperand(1); const APInt *XorC; @@ -1506,8 +1461,8 @@ Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp, // If this is a comparison that tests the signbit (X < 0) or (x > -1), // fold the xor. ICmpInst::Predicate Pred = Cmp.getPredicate(); - if ((Pred == ICmpInst::ICMP_SLT && C->isNullValue()) || - (Pred == ICmpInst::ICMP_SGT && C->isAllOnesValue())) { + bool TrueIfSigned = false; + if (isSignBitCheck(Cmp.getPredicate(), C, TrueIfSigned)) { // If the sign bit of the XorCst is not set, there is no change to // the operation, just stop using the Xor. 
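The isSignBitCheck() call introduced just above generalizes the earlier explicit tests for `slt 0` and `sgt -1`: a compare against a constant is a pure sign-bit test for several predicate/constant combinations, and the helper also reports whether the compare is true when the sign bit is set. A minimal standalone sketch of that classification for a 32-bit value follows; the enum, the name isSignBitCheck32, and the subset of cases covered are illustrative assumptions for this note, not the LLVM API.

// A minimal, self-contained sketch (not the LLVM API) of classifying a
// compare as a "sign bit check" on a 32-bit value. The enum and the
// function name isSignBitCheck32 are hypothetical; only four common
// predicate/constant forms are handled here.
#include <cassert>
#include <cstdint>

enum class Pred { SLT, SGT, ULT, UGT };

// Returns true if "X pred C" depends only on the sign bit of a 32-bit X.
// TrueIfSigned reports whether the compare is true when the sign bit is set.
static bool isSignBitCheck32(Pred P, int64_t C, bool &TrueIfSigned) {
  switch (P) {
  case Pred::SLT: // X <s 0   : true iff the sign bit is set
    TrueIfSigned = true;
    return C == 0;
  case Pred::SGT: // X >s -1  : true iff the sign bit is clear
    TrueIfSigned = false;
    return C == -1;
  case Pred::ULT: // X <u 0x80000000 : true iff the sign bit is clear
    TrueIfSigned = false;
    return static_cast<uint64_t>(C) == 0x80000000ULL;
  case Pred::UGT: // X >u 0x7fffffff : true iff the sign bit is set
    TrueIfSigned = true;
    return static_cast<uint64_t>(C) == 0x7fffffffULL;
  }
  return false;
}

int main() {
  bool TrueIfSigned = false;
  assert(isSignBitCheck32(Pred::SLT, 0, TrueIfSigned) && TrueIfSigned);
  assert(isSignBitCheck32(Pred::SGT, -1, TrueIfSigned) && !TrueIfSigned);
  assert(!isSignBitCheck32(Pred::SLT, 5, TrueIfSigned));
  return 0;
}

As the following hunk shows, TrueIfSigned then selects the inverted sign test emitted directly on X: `icmp sgt X, -1` when the original compare was true for negative values, and `icmp slt X, 0` otherwise.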
@@ -1517,17 +1472,13 @@ Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp, return &Cmp; } - // Was the old condition true if the operand is positive? - bool isTrueIfPositive = Pred == ICmpInst::ICMP_SGT; - - // If so, the new one isn't. - isTrueIfPositive ^= true; - - Constant *CmpConstant = cast<Constant>(Cmp.getOperand(1)); - if (isTrueIfPositive) - return new ICmpInst(ICmpInst::ICMP_SGT, X, SubOne(CmpConstant)); + // Emit the opposite comparison. + if (TrueIfSigned) + return new ICmpInst(ICmpInst::ICMP_SGT, X, + ConstantInt::getAllOnesValue(X->getType())); else - return new ICmpInst(ICmpInst::ICMP_SLT, X, AddOne(CmpConstant)); + return new ICmpInst(ICmpInst::ICMP_SLT, X, + ConstantInt::getNullValue(X->getType())); } if (Xor->hasOneUse()) { @@ -1535,7 +1486,7 @@ Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp, if (!Cmp.isEquality() && XorC->isSignMask()) { Pred = Cmp.isSigned() ? Cmp.getUnsignedPredicate() : Cmp.getSignedPredicate(); - return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), *C ^ *XorC)); + return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), C ^ *XorC)); } // (icmp u/s (xor X ~SignMask), C) -> (icmp s/u X, (xor C ~SignMask)) @@ -1543,18 +1494,18 @@ Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp, Pred = Cmp.isSigned() ? Cmp.getUnsignedPredicate() : Cmp.getSignedPredicate(); Pred = Cmp.getSwappedPredicate(Pred); - return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), *C ^ *XorC)); + return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), C ^ *XorC)); } } // (icmp ugt (xor X, C), ~C) -> (icmp ult X, C) // iff -C is a power of 2 - if (Pred == ICmpInst::ICMP_UGT && *XorC == ~(*C) && (*C + 1).isPowerOf2()) + if (Pred == ICmpInst::ICMP_UGT && *XorC == ~C && (C + 1).isPowerOf2()) return new ICmpInst(ICmpInst::ICMP_ULT, X, Y); // (icmp ult (xor X, C), -C) -> (icmp uge X, C) // iff -C is a power of 2 - if (Pred == ICmpInst::ICMP_ULT && *XorC == -(*C) && C->isPowerOf2()) + if (Pred == ICmpInst::ICMP_ULT && *XorC == -C && C.isPowerOf2()) return new ICmpInst(ICmpInst::ICMP_UGE, X, Y); return nullptr; @@ -1562,7 +1513,7 @@ Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp, /// Fold icmp (and (sh X, Y), C2), C1. Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And, - const APInt *C1, const APInt *C2) { + const APInt &C1, const APInt &C2) { BinaryOperator *Shift = dyn_cast<BinaryOperator>(And->getOperand(0)); if (!Shift || !Shift->isShift()) return nullptr; @@ -1577,32 +1528,35 @@ Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And, const APInt *C3; if (match(Shift->getOperand(1), m_APInt(C3))) { bool CanFold = false; - if (ShiftOpcode == Instruction::AShr) { - // There may be some constraints that make this possible, but nothing - // simple has been discovered yet. - CanFold = false; - } else if (ShiftOpcode == Instruction::Shl) { + if (ShiftOpcode == Instruction::Shl) { // For a left shift, we can fold if the comparison is not signed. We can // also fold a signed comparison if the mask value and comparison value // are not negative. These constraints may not be obvious, but we can // prove that they are correct using an SMT solver. - if (!Cmp.isSigned() || (!C2->isNegative() && !C1->isNegative())) + if (!Cmp.isSigned() || (!C2.isNegative() && !C1.isNegative())) CanFold = true; - } else if (ShiftOpcode == Instruction::LShr) { + } else { + bool IsAshr = ShiftOpcode == Instruction::AShr; // For a logical right shift, we can fold if the comparison is not signed. 
// We can also fold a signed comparison if the shifted mask value and the // shifted comparison value are not negative. These constraints may not be // obvious, but we can prove that they are correct using an SMT solver. - if (!Cmp.isSigned() || - (!C2->shl(*C3).isNegative() && !C1->shl(*C3).isNegative())) - CanFold = true; + // For an arithmetic shift right we can do the same, if we ensure + // the And doesn't use any bits being shifted in. Normally these would + // be turned into lshr by SimplifyDemandedBits, but not if there is an + // additional user. + if (!IsAshr || (C2.shl(*C3).lshr(*C3) == C2)) { + if (!Cmp.isSigned() || + (!C2.shl(*C3).isNegative() && !C1.shl(*C3).isNegative())) + CanFold = true; + } } if (CanFold) { - APInt NewCst = IsShl ? C1->lshr(*C3) : C1->shl(*C3); + APInt NewCst = IsShl ? C1.lshr(*C3) : C1.shl(*C3); APInt SameAsC1 = IsShl ? NewCst.shl(*C3) : NewCst.lshr(*C3); // Check to see if we are shifting out any of the bits being compared. - if (SameAsC1 != *C1) { + if (SameAsC1 != C1) { // If we shifted bits out, the fold is not going to work out. As a // special case, check to see if this means that the result is always // true or false now. @@ -1612,7 +1566,7 @@ Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And, return replaceInstUsesWith(Cmp, ConstantInt::getTrue(Cmp.getType())); } else { Cmp.setOperand(1, ConstantInt::get(And->getType(), NewCst)); - APInt NewAndCst = IsShl ? C2->lshr(*C3) : C2->shl(*C3); + APInt NewAndCst = IsShl ? C2.lshr(*C3) : C2.shl(*C3); And->setOperand(1, ConstantInt::get(And->getType(), NewAndCst)); And->setOperand(0, Shift->getOperand(0)); Worklist.Add(Shift); // Shift is dead. @@ -1624,7 +1578,7 @@ Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And, // Turn ((X >> Y) & C2) == 0 into (X & (C2 << Y)) == 0. The latter is // preferable because it allows the C2 << Y expression to be hoisted out of a // loop if Y is invariant and X is not. - if (Shift->hasOneUse() && C1->isNullValue() && Cmp.isEquality() && + if (Shift->hasOneUse() && C1.isNullValue() && Cmp.isEquality() && !Shift->isArithmeticShift() && !isa<Constant>(Shift->getOperand(0))) { // Compute C2 << Y. Value *NewShift = @@ -1643,12 +1597,12 @@ Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And, /// Fold icmp (and X, C2), C1. Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp, BinaryOperator *And, - const APInt *C1) { + const APInt &C1) { const APInt *C2; if (!match(And->getOperand(1), m_APInt(C2))) return nullptr; - if (!And->hasOneUse() || !And->getOperand(0)->hasOneUse()) + if (!And->hasOneUse()) return nullptr; // If the LHS is an 'and' of a truncate and we can widen the and/compare to @@ -1660,29 +1614,29 @@ Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp, // set or if it is an equality comparison. Extending a relational comparison // when we're checking the sign bit would not work. Value *W; - if (match(And->getOperand(0), m_Trunc(m_Value(W))) && - (Cmp.isEquality() || (!C1->isNegative() && !C2->isNegative()))) { + if (match(And->getOperand(0), m_OneUse(m_Trunc(m_Value(W)))) && + (Cmp.isEquality() || (!C1.isNegative() && !C2->isNegative()))) { // TODO: Is this a good transform for vectors? Wider types may reduce // throughput. Should this transform be limited (even for scalars) by using // shouldChangeType()? 
if (!Cmp.getType()->isVectorTy()) { Type *WideType = W->getType(); unsigned WideScalarBits = WideType->getScalarSizeInBits(); - Constant *ZextC1 = ConstantInt::get(WideType, C1->zext(WideScalarBits)); + Constant *ZextC1 = ConstantInt::get(WideType, C1.zext(WideScalarBits)); Constant *ZextC2 = ConstantInt::get(WideType, C2->zext(WideScalarBits)); Value *NewAnd = Builder.CreateAnd(W, ZextC2, And->getName()); return new ICmpInst(Cmp.getPredicate(), NewAnd, ZextC1); } } - if (Instruction *I = foldICmpAndShift(Cmp, And, C1, C2)) + if (Instruction *I = foldICmpAndShift(Cmp, And, C1, *C2)) return I; // (icmp pred (and (or (lshr A, B), A), 1), 0) --> // (icmp pred (and A, (or (shl 1, B), 1), 0)) // // iff pred isn't signed - if (!Cmp.isSigned() && C1->isNullValue() && + if (!Cmp.isSigned() && C1.isNullValue() && And->getOperand(0)->hasOneUse() && match(And->getOperand(1), m_One())) { Constant *One = cast<Constant>(And->getOperand(1)); Value *Or = And->getOperand(0); @@ -1716,22 +1670,13 @@ Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp, } } - // (X & C2) > C1 --> (X & C2) != 0, if any bit set in (X & C2) will produce a - // result greater than C1. - unsigned NumTZ = C2->countTrailingZeros(); - if (Cmp.getPredicate() == ICmpInst::ICMP_UGT && NumTZ < C2->getBitWidth() && - APInt::getOneBitSet(C2->getBitWidth(), NumTZ).ugt(*C1)) { - Constant *Zero = Constant::getNullValue(And->getType()); - return new ICmpInst(ICmpInst::ICMP_NE, And, Zero); - } - return nullptr; } /// Fold icmp (and X, Y), C. Instruction *InstCombiner::foldICmpAndConstant(ICmpInst &Cmp, BinaryOperator *And, - const APInt *C) { + const APInt &C) { if (Instruction *I = foldICmpAndConstConst(Cmp, And, C)) return I; @@ -1756,7 +1701,7 @@ Instruction *InstCombiner::foldICmpAndConstant(ICmpInst &Cmp, // X & -C == -C -> X > u ~C // X & -C != -C -> X <= u ~C // iff C is a power of 2 - if (Cmp.getOperand(1) == Y && (-(*C)).isPowerOf2()) { + if (Cmp.getOperand(1) == Y && (-C).isPowerOf2()) { auto NewPred = Cmp.getPredicate() == CmpInst::ICMP_EQ ? CmpInst::ICMP_UGT : CmpInst::ICMP_ULE; return new ICmpInst(NewPred, X, SubOne(cast<Constant>(Cmp.getOperand(1)))); @@ -1766,7 +1711,7 @@ Instruction *InstCombiner::foldICmpAndConstant(ICmpInst &Cmp, // (X & C2) != 0 -> (trunc X) < 0 // iff C2 is a power of 2 and it masks the sign bit of a legal integer type. const APInt *C2; - if (And->hasOneUse() && C->isNullValue() && match(Y, m_APInt(C2))) { + if (And->hasOneUse() && C.isNullValue() && match(Y, m_APInt(C2))) { int32_t ExactLogBase2 = C2->exactLogBase2(); if (ExactLogBase2 != -1 && DL.isLegalInteger(ExactLogBase2 + 1)) { Type *NTy = IntegerType::get(Cmp.getContext(), ExactLogBase2 + 1); @@ -1784,9 +1729,9 @@ Instruction *InstCombiner::foldICmpAndConstant(ICmpInst &Cmp, /// Fold icmp (or X, Y), C. Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or, - const APInt *C) { + const APInt &C) { ICmpInst::Predicate Pred = Cmp.getPredicate(); - if (C->isOneValue()) { + if (C.isOneValue()) { // icmp slt signum(V) 1 --> icmp slt V, 1 Value *V = nullptr; if (Pred == ICmpInst::ICMP_SLT && match(Or, m_Signum(m_Value(V)))) @@ -1798,12 +1743,12 @@ Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or, // X | C != C --> X >u C // iff C+1 is a power of 2 (C is a bitmask of the low bits) if (Cmp.isEquality() && Cmp.getOperand(1) == Or->getOperand(1) && - (*C + 1).isPowerOf2()) { + (C + 1).isPowerOf2()) { Pred = (Pred == CmpInst::ICMP_EQ) ? 
CmpInst::ICMP_ULE : CmpInst::ICMP_UGT; return new ICmpInst(Pred, Or->getOperand(0), Or->getOperand(1)); } - if (!Cmp.isEquality() || !C->isNullValue() || !Or->hasOneUse()) + if (!Cmp.isEquality() || !C.isNullValue() || !Or->hasOneUse()) return nullptr; Value *P, *Q; @@ -1837,7 +1782,7 @@ Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or, /// Fold icmp (mul X, Y), C. Instruction *InstCombiner::foldICmpMulConstant(ICmpInst &Cmp, BinaryOperator *Mul, - const APInt *C) { + const APInt &C) { const APInt *MulC; if (!match(Mul->getOperand(1), m_APInt(MulC))) return nullptr; @@ -1845,7 +1790,7 @@ Instruction *InstCombiner::foldICmpMulConstant(ICmpInst &Cmp, // If this is a test of the sign bit and the multiply is sign-preserving with // a constant operand, use the multiply LHS operand instead. ICmpInst::Predicate Pred = Cmp.getPredicate(); - if (isSignTest(Pred, *C) && Mul->hasNoSignedWrap()) { + if (isSignTest(Pred, C) && Mul->hasNoSignedWrap()) { if (MulC->isNegative()) Pred = ICmpInst::getSwappedPredicate(Pred); return new ICmpInst(Pred, Mul->getOperand(0), @@ -1857,14 +1802,14 @@ Instruction *InstCombiner::foldICmpMulConstant(ICmpInst &Cmp, /// Fold icmp (shl 1, Y), C. static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl, - const APInt *C) { + const APInt &C) { Value *Y; if (!match(Shl, m_Shl(m_One(), m_Value(Y)))) return nullptr; Type *ShiftType = Shl->getType(); - uint32_t TypeBits = C->getBitWidth(); - bool CIsPowerOf2 = C->isPowerOf2(); + unsigned TypeBits = C.getBitWidth(); + bool CIsPowerOf2 = C.isPowerOf2(); ICmpInst::Predicate Pred = Cmp.getPredicate(); if (Cmp.isUnsigned()) { // (1 << Y) pred C -> Y pred Log2(C) @@ -1881,7 +1826,7 @@ static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl, // (1 << Y) >= 2147483648 -> Y >= 31 -> Y == 31 // (1 << Y) < 2147483648 -> Y < 31 -> Y != 31 - unsigned CLog2 = C->logBase2(); + unsigned CLog2 = C.logBase2(); if (CLog2 == TypeBits - 1) { if (Pred == ICmpInst::ICMP_UGE) Pred = ICmpInst::ICMP_EQ; @@ -1891,7 +1836,7 @@ static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl, return new ICmpInst(Pred, Y, ConstantInt::get(ShiftType, CLog2)); } else if (Cmp.isSigned()) { Constant *BitWidthMinusOne = ConstantInt::get(ShiftType, TypeBits - 1); - if (C->isAllOnesValue()) { + if (C.isAllOnesValue()) { // (1 << Y) <= -1 -> Y == 31 if (Pred == ICmpInst::ICMP_SLE) return new ICmpInst(ICmpInst::ICMP_EQ, Y, BitWidthMinusOne); @@ -1899,7 +1844,7 @@ static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl, // (1 << Y) > -1 -> Y != 31 if (Pred == ICmpInst::ICMP_SGT) return new ICmpInst(ICmpInst::ICMP_NE, Y, BitWidthMinusOne); - } else if (!(*C)) { + } else if (!C) { // (1 << Y) < 0 -> Y == 31 // (1 << Y) <= 0 -> Y == 31 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) @@ -1911,7 +1856,7 @@ static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl, return new ICmpInst(ICmpInst::ICMP_NE, Y, BitWidthMinusOne); } } else if (Cmp.isEquality() && CIsPowerOf2) { - return new ICmpInst(Pred, Y, ConstantInt::get(ShiftType, C->logBase2())); + return new ICmpInst(Pred, Y, ConstantInt::get(ShiftType, C.logBase2())); } return nullptr; @@ -1920,10 +1865,10 @@ static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl, /// Fold icmp (shl X, Y), C. 
Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, BinaryOperator *Shl, - const APInt *C) { + const APInt &C) { const APInt *ShiftVal; if (Cmp.isEquality() && match(Shl->getOperand(0), m_APInt(ShiftVal))) - return foldICmpShlConstConst(Cmp, Shl->getOperand(1), *C, *ShiftVal); + return foldICmpShlConstConst(Cmp, Shl->getOperand(1), C, *ShiftVal); const APInt *ShiftAmt; if (!match(Shl->getOperand(1), m_APInt(ShiftAmt))) @@ -1931,7 +1876,7 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, // Check that the shift amount is in range. If not, don't perform undefined // shifts. When the shift is visited, it will be simplified. - unsigned TypeBits = C->getBitWidth(); + unsigned TypeBits = C.getBitWidth(); if (ShiftAmt->uge(TypeBits)) return nullptr; @@ -1945,15 +1890,15 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, if (Shl->hasNoSignedWrap()) { if (Pred == ICmpInst::ICMP_SGT) { // icmp Pred (shl nsw X, ShiftAmt), C --> icmp Pred X, (C >>s ShiftAmt) - APInt ShiftedC = C->ashr(*ShiftAmt); + APInt ShiftedC = C.ashr(*ShiftAmt); return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); } if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) { // This is the same code as the SGT case, but assert the pre-condition // that is needed for this to work with equality predicates. - assert(C->ashr(*ShiftAmt).shl(*ShiftAmt) == *C && + assert(C.ashr(*ShiftAmt).shl(*ShiftAmt) == C && "Compare known true or false was not folded"); - APInt ShiftedC = C->ashr(*ShiftAmt); + APInt ShiftedC = C.ashr(*ShiftAmt); return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); } if (Pred == ICmpInst::ICMP_SLT) { @@ -1961,14 +1906,14 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, // (X << S) <=s C is equiv to X <=s (C >> S) for all C // (X << S) <s (C + 1) is equiv to X <s (C >> S) + 1 if C <s SMAX // (X << S) <s C is equiv to X <s ((C - 1) >> S) + 1 if C >s SMIN - assert(!C->isMinSignedValue() && "Unexpected icmp slt"); - APInt ShiftedC = (*C - 1).ashr(*ShiftAmt) + 1; + assert(!C.isMinSignedValue() && "Unexpected icmp slt"); + APInt ShiftedC = (C - 1).ashr(*ShiftAmt) + 1; return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); } // If this is a signed comparison to 0 and the shift is sign preserving, // use the shift LHS operand instead; isSignTest may change 'Pred', so only // do that if we're sure to not continue on in this function. - if (isSignTest(Pred, *C)) + if (isSignTest(Pred, C)) return new ICmpInst(Pred, X, Constant::getNullValue(ShType)); } @@ -1978,15 +1923,15 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, if (Shl->hasNoUnsignedWrap()) { if (Pred == ICmpInst::ICMP_UGT) { // icmp Pred (shl nuw X, ShiftAmt), C --> icmp Pred X, (C >>u ShiftAmt) - APInt ShiftedC = C->lshr(*ShiftAmt); + APInt ShiftedC = C.lshr(*ShiftAmt); return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); } if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) { // This is the same code as the UGT case, but assert the pre-condition // that is needed for this to work with equality predicates. 
- assert(C->lshr(*ShiftAmt).shl(*ShiftAmt) == *C && + assert(C.lshr(*ShiftAmt).shl(*ShiftAmt) == C && "Compare known true or false was not folded"); - APInt ShiftedC = C->lshr(*ShiftAmt); + APInt ShiftedC = C.lshr(*ShiftAmt); return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); } if (Pred == ICmpInst::ICMP_ULT) { @@ -1994,8 +1939,8 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, // (X << S) <=u C is equiv to X <=u (C >> S) for all C // (X << S) <u (C + 1) is equiv to X <u (C >> S) + 1 if C <u ~0u // (X << S) <u C is equiv to X <u ((C - 1) >> S) + 1 if C >u 0 - assert(C->ugt(0) && "ult 0 should have been eliminated"); - APInt ShiftedC = (*C - 1).lshr(*ShiftAmt) + 1; + assert(C.ugt(0) && "ult 0 should have been eliminated"); + APInt ShiftedC = (C - 1).lshr(*ShiftAmt) + 1; return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); } } @@ -2006,13 +1951,13 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, ShType, APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt->getZExtValue())); Value *And = Builder.CreateAnd(X, Mask, Shl->getName() + ".mask"); - Constant *LShrC = ConstantInt::get(ShType, C->lshr(*ShiftAmt)); + Constant *LShrC = ConstantInt::get(ShType, C.lshr(*ShiftAmt)); return new ICmpInst(Pred, And, LShrC); } // Otherwise, if this is a comparison of the sign bit, simplify to and/test. bool TrueIfSigned = false; - if (Shl->hasOneUse() && isSignBitCheck(Pred, *C, TrueIfSigned)) { + if (Shl->hasOneUse() && isSignBitCheck(Pred, C, TrueIfSigned)) { // (X << 31) <s 0 --> (X & 1) != 0 Constant *Mask = ConstantInt::get( ShType, @@ -2029,13 +1974,13 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, // free on the target. It has the additional benefit of comparing to a // smaller constant that may be more target-friendly. unsigned Amt = ShiftAmt->getLimitedValue(TypeBits - 1); - if (Shl->hasOneUse() && Amt != 0 && C->countTrailingZeros() >= Amt && + if (Shl->hasOneUse() && Amt != 0 && C.countTrailingZeros() >= Amt && DL.isLegalInteger(TypeBits - Amt)) { Type *TruncTy = IntegerType::get(Cmp.getContext(), TypeBits - Amt); if (ShType->isVectorTy()) TruncTy = VectorType::get(TruncTy, ShType->getVectorNumElements()); Constant *NewC = - ConstantInt::get(TruncTy, C->ashr(*ShiftAmt).trunc(TypeBits - Amt)); + ConstantInt::get(TruncTy, C.ashr(*ShiftAmt).trunc(TypeBits - Amt)); return new ICmpInst(Pred, Builder.CreateTrunc(X, TruncTy), NewC); } @@ -2045,18 +1990,18 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, /// Fold icmp ({al}shr X, Y), C. Instruction *InstCombiner::foldICmpShrConstant(ICmpInst &Cmp, BinaryOperator *Shr, - const APInt *C) { + const APInt &C) { // An exact shr only shifts out zero bits, so: // icmp eq/ne (shr X, Y), 0 --> icmp eq/ne X, 0 Value *X = Shr->getOperand(0); CmpInst::Predicate Pred = Cmp.getPredicate(); if (Cmp.isEquality() && Shr->isExact() && Shr->hasOneUse() && - C->isNullValue()) + C.isNullValue()) return new ICmpInst(Pred, X, Cmp.getOperand(1)); const APInt *ShiftVal; if (Cmp.isEquality() && match(Shr->getOperand(0), m_APInt(ShiftVal))) - return foldICmpShrConstConst(Cmp, Shr->getOperand(1), *C, *ShiftVal); + return foldICmpShrConstConst(Cmp, Shr->getOperand(1), C, *ShiftVal); const APInt *ShiftAmt; if (!match(Shr->getOperand(1), m_APInt(ShiftAmt))) @@ -2064,71 +2009,73 @@ Instruction *InstCombiner::foldICmpShrConstant(ICmpInst &Cmp, // Check that the shift amount is in range. If not, don't perform undefined // shifts. When the shift is visited it will be simplified. 
- unsigned TypeBits = C->getBitWidth(); + unsigned TypeBits = C.getBitWidth(); unsigned ShAmtVal = ShiftAmt->getLimitedValue(TypeBits); if (ShAmtVal >= TypeBits || ShAmtVal == 0) return nullptr; bool IsAShr = Shr->getOpcode() == Instruction::AShr; - if (!Cmp.isEquality()) { - // If we have an unsigned comparison and an ashr, we can't simplify this. - // Similarly for signed comparisons with lshr. - if (Cmp.isSigned() != IsAShr) - return nullptr; - - // Otherwise, all lshr and most exact ashr's are equivalent to a udiv/sdiv - // by a power of 2. Since we already have logic to simplify these, - // transform to div and then simplify the resultant comparison. - if (IsAShr && (!Shr->isExact() || ShAmtVal == TypeBits - 1)) - return nullptr; - - // Revisit the shift (to delete it). - Worklist.Add(Shr); - - Constant *DivCst = ConstantInt::get( - Shr->getType(), APInt::getOneBitSet(TypeBits, ShAmtVal)); - - Value *Tmp = IsAShr ? Builder.CreateSDiv(X, DivCst, "", Shr->isExact()) - : Builder.CreateUDiv(X, DivCst, "", Shr->isExact()); - - Cmp.setOperand(0, Tmp); - - // If the builder folded the binop, just return it. - BinaryOperator *TheDiv = dyn_cast<BinaryOperator>(Tmp); - if (!TheDiv) - return &Cmp; - - // Otherwise, fold this div/compare. - assert(TheDiv->getOpcode() == Instruction::SDiv || - TheDiv->getOpcode() == Instruction::UDiv); - - Instruction *Res = foldICmpDivConstant(Cmp, TheDiv, C); - assert(Res && "This div/cst should have folded!"); - return Res; + bool IsExact = Shr->isExact(); + Type *ShrTy = Shr->getType(); + // TODO: If we could guarantee that InstSimplify would handle all of the + // constant-value-based preconditions in the folds below, then we could assert + // those conditions rather than checking them. This is difficult because of + // undef/poison (PR34838). + if (IsAShr) { + if (Pred == CmpInst::ICMP_SLT || (Pred == CmpInst::ICMP_SGT && IsExact)) { + // icmp slt (ashr X, ShAmtC), C --> icmp slt X, (C << ShAmtC) + // icmp sgt (ashr exact X, ShAmtC), C --> icmp sgt X, (C << ShAmtC) + APInt ShiftedC = C.shl(ShAmtVal); + if (ShiftedC.ashr(ShAmtVal) == C) + return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, ShiftedC)); + } + if (Pred == CmpInst::ICMP_SGT) { + // icmp sgt (ashr X, ShAmtC), C --> icmp sgt X, ((C + 1) << ShAmtC) - 1 + APInt ShiftedC = (C + 1).shl(ShAmtVal) - 1; + if (!C.isMaxSignedValue() && !(C + 1).shl(ShAmtVal).isMinSignedValue() && + (ShiftedC + 1).ashr(ShAmtVal) == (C + 1)) + return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, ShiftedC)); + } + } else { + if (Pred == CmpInst::ICMP_ULT || (Pred == CmpInst::ICMP_UGT && IsExact)) { + // icmp ult (lshr X, ShAmtC), C --> icmp ult X, (C << ShAmtC) + // icmp ugt (lshr exact X, ShAmtC), C --> icmp ugt X, (C << ShAmtC) + APInt ShiftedC = C.shl(ShAmtVal); + if (ShiftedC.lshr(ShAmtVal) == C) + return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, ShiftedC)); + } + if (Pred == CmpInst::ICMP_UGT) { + // icmp ugt (lshr X, ShAmtC), C --> icmp ugt X, ((C + 1) << ShAmtC) - 1 + APInt ShiftedC = (C + 1).shl(ShAmtVal) - 1; + if ((ShiftedC + 1).lshr(ShAmtVal) == (C + 1)) + return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, ShiftedC)); + } } + if (!Cmp.isEquality()) + return nullptr; + // Handle equality comparisons of shift-by-constant. // If the comparison constant changes with the shift, the comparison cannot // succeed (bits of the comparison constant cannot match the shifted value). // This should be known by InstSimplify and already be folded to true/false. 
- assert(((IsAShr && C->shl(ShAmtVal).ashr(ShAmtVal) == *C) || - (!IsAShr && C->shl(ShAmtVal).lshr(ShAmtVal) == *C)) && + assert(((IsAShr && C.shl(ShAmtVal).ashr(ShAmtVal) == C) || + (!IsAShr && C.shl(ShAmtVal).lshr(ShAmtVal) == C)) && "Expected icmp+shr simplify did not occur."); - // Check if the bits shifted out are known to be zero. If so, we can compare - // against the unshifted value: + // If the bits shifted out are known zero, compare the unshifted value: // (X & 4) >> 1 == 2 --> (X & 4) == 4. - Constant *ShiftedCmpRHS = ConstantInt::get(Shr->getType(), *C << ShAmtVal); - if (Shr->hasOneUse()) { - if (Shr->isExact()) - return new ICmpInst(Pred, X, ShiftedCmpRHS); + if (Shr->isExact()) + return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, C << ShAmtVal)); - // Otherwise strength reduce the shift into an 'and'. + if (Shr->hasOneUse()) { + // Canonicalize the shift into an 'and': + // icmp eq/ne (shr X, ShAmt), C --> icmp eq/ne (and X, HiMask), (C << ShAmt) APInt Val(APInt::getHighBitsSet(TypeBits, TypeBits - ShAmtVal)); - Constant *Mask = ConstantInt::get(Shr->getType(), Val); + Constant *Mask = ConstantInt::get(ShrTy, Val); Value *And = Builder.CreateAnd(X, Mask, Shr->getName() + ".mask"); - return new ICmpInst(Pred, And, ShiftedCmpRHS); + return new ICmpInst(Pred, And, ConstantInt::get(ShrTy, C << ShAmtVal)); } return nullptr; @@ -2137,7 +2084,7 @@ Instruction *InstCombiner::foldICmpShrConstant(ICmpInst &Cmp, /// Fold icmp (udiv X, Y), C. Instruction *InstCombiner::foldICmpUDivConstant(ICmpInst &Cmp, BinaryOperator *UDiv, - const APInt *C) { + const APInt &C) { const APInt *C2; if (!match(UDiv->getOperand(0), m_APInt(C2))) return nullptr; @@ -2147,17 +2094,17 @@ Instruction *InstCombiner::foldICmpUDivConstant(ICmpInst &Cmp, // (icmp ugt (udiv C2, Y), C) -> (icmp ule Y, C2/(C+1)) Value *Y = UDiv->getOperand(1); if (Cmp.getPredicate() == ICmpInst::ICMP_UGT) { - assert(!C->isMaxValue() && + assert(!C.isMaxValue() && "icmp ugt X, UINT_MAX should have been simplified already."); return new ICmpInst(ICmpInst::ICMP_ULE, Y, - ConstantInt::get(Y->getType(), C2->udiv(*C + 1))); + ConstantInt::get(Y->getType(), C2->udiv(C + 1))); } // (icmp ult (udiv C2, Y), C) -> (icmp ugt Y, C2/C) if (Cmp.getPredicate() == ICmpInst::ICMP_ULT) { - assert(*C != 0 && "icmp ult X, 0 should have been simplified already."); + assert(C != 0 && "icmp ult X, 0 should have been simplified already."); return new ICmpInst(ICmpInst::ICMP_UGT, Y, - ConstantInt::get(Y->getType(), C2->udiv(*C))); + ConstantInt::get(Y->getType(), C2->udiv(C))); } return nullptr; @@ -2166,7 +2113,7 @@ Instruction *InstCombiner::foldICmpUDivConstant(ICmpInst &Cmp, /// Fold icmp ({su}div X, Y), C. Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp, BinaryOperator *Div, - const APInt *C) { + const APInt &C) { // Fold: icmp pred ([us]div X, C2), C -> range test // Fold this div into the comparison, producing a range check. // Determine, based on the divide type, what the range is being @@ -2197,28 +2144,22 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp, (DivIsSigned && C2->isAllOnesValue())) return nullptr; - // TODO: We could do all of the computations below using APInt. - Constant *CmpRHS = cast<Constant>(Cmp.getOperand(1)); - Constant *DivRHS = cast<Constant>(Div->getOperand(1)); - - // Compute Prod = CmpRHS * DivRHS. We are essentially solving an equation of - // form X / C2 = C. We solve for X by multiplying C2 (DivRHS) and C (CmpRHS). + // Compute Prod = C * C2. 
We are essentially solving an equation of + // form X / C2 = C. We solve for X by multiplying C2 and C. // By solving for X, we can turn this into a range check instead of computing // a divide. - Constant *Prod = ConstantExpr::getMul(CmpRHS, DivRHS); + APInt Prod = C * *C2; // Determine if the product overflows by seeing if the product is not equal to // the divide. Make sure we do the same kind of divide as in the LHS // instruction that we're folding. - bool ProdOV = (DivIsSigned ? ConstantExpr::getSDiv(Prod, DivRHS) - : ConstantExpr::getUDiv(Prod, DivRHS)) != CmpRHS; + bool ProdOV = (DivIsSigned ? Prod.sdiv(*C2) : Prod.udiv(*C2)) != C; ICmpInst::Predicate Pred = Cmp.getPredicate(); // If the division is known to be exact, then there is no remainder from the // divide, so the covered range size is unit, otherwise it is the divisor. - Constant *RangeSize = - Div->isExact() ? ConstantInt::get(Div->getType(), 1) : DivRHS; + APInt RangeSize = Div->isExact() ? APInt(C2->getBitWidth(), 1) : *C2; // Figure out the interval that is being checked. For example, a comparison // like "X /u 5 == 0" is really checking that X is in the interval [0, 5). @@ -2228,7 +2169,7 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp, // overflow variable is set to 0 if it's corresponding bound variable is valid // -1 if overflowed off the bottom end, or +1 if overflowed off the top end. int LoOverflow = 0, HiOverflow = 0; - Constant *LoBound = nullptr, *HiBound = nullptr; + APInt LoBound, HiBound; if (!DivIsSigned) { // udiv // e.g. X/5 op 3 --> [15, 20) @@ -2240,38 +2181,38 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp, HiOverflow = addWithOverflow(HiBound, LoBound, RangeSize, false); } } else if (C2->isStrictlyPositive()) { // Divisor is > 0. - if (C->isNullValue()) { // (X / pos) op 0 + if (C.isNullValue()) { // (X / pos) op 0 // Can't overflow. e.g. X/2 op 0 --> [-1, 2) - LoBound = ConstantExpr::getNeg(SubOne(RangeSize)); + LoBound = -(RangeSize - 1); HiBound = RangeSize; - } else if (C->isStrictlyPositive()) { // (X / pos) op pos + } else if (C.isStrictlyPositive()) { // (X / pos) op pos LoBound = Prod; // e.g. X/5 op 3 --> [15, 20) HiOverflow = LoOverflow = ProdOV; if (!HiOverflow) HiOverflow = addWithOverflow(HiBound, Prod, RangeSize, true); } else { // (X / pos) op neg // e.g. X/5 op -3 --> [-15-4, -15+1) --> [-19, -14) - HiBound = AddOne(Prod); + HiBound = Prod + 1; LoOverflow = HiOverflow = ProdOV ? -1 : 0; if (!LoOverflow) { - Constant *DivNeg = ConstantExpr::getNeg(RangeSize); + APInt DivNeg = -RangeSize; LoOverflow = addWithOverflow(LoBound, HiBound, DivNeg, true) ? -1 : 0; } } } else if (C2->isNegative()) { // Divisor is < 0. if (Div->isExact()) - RangeSize = ConstantExpr::getNeg(RangeSize); - if (C->isNullValue()) { // (X / neg) op 0 + RangeSize.negate(); + if (C.isNullValue()) { // (X / neg) op 0 // e.g. X/-5 op 0 --> [-4, 5) - LoBound = AddOne(RangeSize); - HiBound = ConstantExpr::getNeg(RangeSize); - if (HiBound == DivRHS) { // -INTMIN = INTMIN + LoBound = RangeSize + 1; + HiBound = -RangeSize; + if (HiBound == *C2) { // -INTMIN = INTMIN HiOverflow = 1; // [INTMIN+1, overflow) - HiBound = nullptr; // e.g. X/INTMIN = 0 --> X > INTMIN + HiBound = APInt(); // e.g. X/INTMIN = 0 --> X > INTMIN } - } else if (C->isStrictlyPositive()) { // (X / neg) op pos + } else if (C.isStrictlyPositive()) { // (X / neg) op pos // e.g. X/-5 op 3 --> [-19, -14) - HiBound = AddOne(Prod); + HiBound = Prod + 1; HiOverflow = LoOverflow = ProdOV ? 
-1 : 0; if (!LoOverflow) LoOverflow = addWithOverflow(LoBound, HiBound, RangeSize, true) ? -1:0; @@ -2294,25 +2235,27 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp, return replaceInstUsesWith(Cmp, Builder.getFalse()); if (HiOverflow) return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE : - ICmpInst::ICMP_UGE, X, LoBound); + ICmpInst::ICMP_UGE, X, + ConstantInt::get(Div->getType(), LoBound)); if (LoOverflow) return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : - ICmpInst::ICMP_ULT, X, HiBound); + ICmpInst::ICMP_ULT, X, + ConstantInt::get(Div->getType(), HiBound)); return replaceInstUsesWith( - Cmp, insertRangeTest(X, LoBound->getUniqueInteger(), - HiBound->getUniqueInteger(), DivIsSigned, true)); + Cmp, insertRangeTest(X, LoBound, HiBound, DivIsSigned, true)); case ICmpInst::ICMP_NE: if (LoOverflow && HiOverflow) return replaceInstUsesWith(Cmp, Builder.getTrue()); if (HiOverflow) return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : - ICmpInst::ICMP_ULT, X, LoBound); + ICmpInst::ICMP_ULT, X, + ConstantInt::get(Div->getType(), LoBound)); if (LoOverflow) return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE : - ICmpInst::ICMP_UGE, X, HiBound); + ICmpInst::ICMP_UGE, X, + ConstantInt::get(Div->getType(), HiBound)); return replaceInstUsesWith(Cmp, - insertRangeTest(X, LoBound->getUniqueInteger(), - HiBound->getUniqueInteger(), + insertRangeTest(X, LoBound, HiBound, DivIsSigned, false)); case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_SLT: @@ -2320,7 +2263,7 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp, return replaceInstUsesWith(Cmp, Builder.getTrue()); if (LoOverflow == -1) // Low bound is less than input range. return replaceInstUsesWith(Cmp, Builder.getFalse()); - return new ICmpInst(Pred, X, LoBound); + return new ICmpInst(Pred, X, ConstantInt::get(Div->getType(), LoBound)); case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_SGT: if (HiOverflow == +1) // High bound greater than input range. @@ -2328,8 +2271,10 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp, if (HiOverflow == -1) // High bound less than input range. return replaceInstUsesWith(Cmp, Builder.getTrue()); if (Pred == ICmpInst::ICMP_UGT) - return new ICmpInst(ICmpInst::ICMP_UGE, X, HiBound); - return new ICmpInst(ICmpInst::ICMP_SGE, X, HiBound); + return new ICmpInst(ICmpInst::ICMP_UGE, X, + ConstantInt::get(Div->getType(), HiBound)); + return new ICmpInst(ICmpInst::ICMP_SGE, X, + ConstantInt::get(Div->getType(), HiBound)); } return nullptr; @@ -2338,7 +2283,7 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp, /// Fold icmp (sub X, Y), C. 
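// A standalone sanity check of the range-test view used by foldICmpDivConstant above,
// using plain C++ unsigned arithmetic instead of APInt: the example from the comments,
// "X /u 5 == 3", holds exactly for X in the half-open interval [15, 20).
#include <cassert>
#include <cstdint>
int main() {
  for (uint32_t X = 0; X < 1000; ++X)
    assert((X / 5 == 3) == (X >= 15 && X < 20));
  return 0;
}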
Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp, BinaryOperator *Sub, - const APInt *C) { + const APInt &C) { Value *X = Sub->getOperand(0), *Y = Sub->getOperand(1); ICmpInst::Predicate Pred = Cmp.getPredicate(); @@ -2349,19 +2294,19 @@ Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp, if (Sub->hasNoSignedWrap()) { // (icmp sgt (sub nsw X, Y), -1) -> (icmp sge X, Y) - if (Pred == ICmpInst::ICMP_SGT && C->isAllOnesValue()) + if (Pred == ICmpInst::ICMP_SGT && C.isAllOnesValue()) return new ICmpInst(ICmpInst::ICMP_SGE, X, Y); // (icmp sgt (sub nsw X, Y), 0) -> (icmp sgt X, Y) - if (Pred == ICmpInst::ICMP_SGT && C->isNullValue()) + if (Pred == ICmpInst::ICMP_SGT && C.isNullValue()) return new ICmpInst(ICmpInst::ICMP_SGT, X, Y); // (icmp slt (sub nsw X, Y), 0) -> (icmp slt X, Y) - if (Pred == ICmpInst::ICMP_SLT && C->isNullValue()) + if (Pred == ICmpInst::ICMP_SLT && C.isNullValue()) return new ICmpInst(ICmpInst::ICMP_SLT, X, Y); // (icmp slt (sub nsw X, Y), 1) -> (icmp sle X, Y) - if (Pred == ICmpInst::ICMP_SLT && C->isOneValue()) + if (Pred == ICmpInst::ICMP_SLT && C.isOneValue()) return new ICmpInst(ICmpInst::ICMP_SLE, X, Y); } @@ -2371,14 +2316,14 @@ Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp, // C2 - Y <u C -> (Y | (C - 1)) == C2 // iff (C2 & (C - 1)) == C - 1 and C is a power of 2 - if (Pred == ICmpInst::ICMP_ULT && C->isPowerOf2() && - (*C2 & (*C - 1)) == (*C - 1)) - return new ICmpInst(ICmpInst::ICMP_EQ, Builder.CreateOr(Y, *C - 1), X); + if (Pred == ICmpInst::ICMP_ULT && C.isPowerOf2() && + (*C2 & (C - 1)) == (C - 1)) + return new ICmpInst(ICmpInst::ICMP_EQ, Builder.CreateOr(Y, C - 1), X); // C2 - Y >u C -> (Y | C) != C2 // iff C2 & C == C and C + 1 is a power of 2 - if (Pred == ICmpInst::ICMP_UGT && (*C + 1).isPowerOf2() && (*C2 & *C) == *C) - return new ICmpInst(ICmpInst::ICMP_NE, Builder.CreateOr(Y, *C), X); + if (Pred == ICmpInst::ICMP_UGT && (C + 1).isPowerOf2() && (*C2 & C) == C) + return new ICmpInst(ICmpInst::ICMP_NE, Builder.CreateOr(Y, C), X); return nullptr; } @@ -2386,7 +2331,7 @@ Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp, /// Fold icmp (add X, Y), C. Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp, BinaryOperator *Add, - const APInt *C) { + const APInt &C) { Value *Y = Add->getOperand(1); const APInt *C2; if (Cmp.isEquality() || !match(Y, m_APInt(C2))) @@ -2403,7 +2348,7 @@ Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp, if (Add->hasNoSignedWrap() && (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLT)) { bool Overflow; - APInt NewC = C->ssub_ov(*C2, Overflow); + APInt NewC = C.ssub_ov(*C2, Overflow); // If there is overflow, the result must be true or false. // TODO: Can we assert there is no overflow because InstSimplify always // handles those cases? 
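// A standalone check (uint8_t wrap-around arithmetic, not InstCombine code) of the
// foldICmpSubConstant rule shown above:
//   C2 - Y <u C  -->  (Y | (C - 1)) == C2,  when (C2 & (C - 1)) == C - 1 and C is a power of 2.
// The demo constants C = 4 and C2 = 7 satisfy the precondition (7 & 3 == 3).
#include <cassert>
#include <cstdint>
int main() {
  const unsigned C = 4, C2 = 7;
  for (unsigned y = 0; y < 256; ++y) {
    uint8_t Y = static_cast<uint8_t>(y);
    bool Lhs = static_cast<uint8_t>(C2 - Y) < C;            // 8-bit unsigned subtraction
    bool Rhs = static_cast<uint8_t>(Y | (C - 1)) == C2;
    assert(Lhs == Rhs);
  }
  return 0;
}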
@@ -2412,7 +2357,7 @@ Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp, return new ICmpInst(Pred, X, ConstantInt::get(Ty, NewC)); } - auto CR = ConstantRange::makeExactICmpRegion(Pred, *C).subtract(*C2); + auto CR = ConstantRange::makeExactICmpRegion(Pred, C).subtract(*C2); const APInt &Upper = CR.getUpper(); const APInt &Lower = CR.getLower(); if (Cmp.isSigned()) { @@ -2433,15 +2378,15 @@ Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp, // X+C <u C2 -> (X & -C2) == C // iff C & (C2-1) == 0 // C2 is a power of 2 - if (Pred == ICmpInst::ICMP_ULT && C->isPowerOf2() && (*C2 & (*C - 1)) == 0) - return new ICmpInst(ICmpInst::ICMP_EQ, Builder.CreateAnd(X, -(*C)), + if (Pred == ICmpInst::ICMP_ULT && C.isPowerOf2() && (*C2 & (C - 1)) == 0) + return new ICmpInst(ICmpInst::ICMP_EQ, Builder.CreateAnd(X, -C), ConstantExpr::getNeg(cast<Constant>(Y))); // X+C >u C2 -> (X & ~C2) != C // iff C & C2 == 0 // C2+1 is a power of 2 - if (Pred == ICmpInst::ICMP_UGT && (*C + 1).isPowerOf2() && (*C2 & *C) == 0) - return new ICmpInst(ICmpInst::ICMP_NE, Builder.CreateAnd(X, ~(*C)), + if (Pred == ICmpInst::ICMP_UGT && (C + 1).isPowerOf2() && (*C2 & C) == 0) + return new ICmpInst(ICmpInst::ICMP_NE, Builder.CreateAnd(X, ~C), ConstantExpr::getNeg(cast<Constant>(Y))); return nullptr; @@ -2471,7 +2416,7 @@ bool InstCombiner::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS, } Instruction *InstCombiner::foldICmpSelectConstant(ICmpInst &Cmp, - Instruction *Select, + SelectInst *Select, ConstantInt *C) { assert(C && "Cmp RHS should be a constant int!"); @@ -2483,8 +2428,8 @@ Instruction *InstCombiner::foldICmpSelectConstant(ICmpInst &Cmp, Value *OrigLHS, *OrigRHS; ConstantInt *C1LessThan, *C2Equal, *C3GreaterThan; if (Cmp.hasOneUse() && - matchThreeWayIntCompare(cast<SelectInst>(Select), OrigLHS, OrigRHS, - C1LessThan, C2Equal, C3GreaterThan)) { + matchThreeWayIntCompare(Select, OrigLHS, OrigRHS, C1LessThan, C2Equal, + C3GreaterThan)) { assert(C1LessThan && C2Equal && C3GreaterThan); bool TrueWhenLessThan = @@ -2525,82 +2470,74 @@ Instruction *InstCombiner::foldICmpInstWithConstant(ICmpInst &Cmp) { if (!match(Cmp.getOperand(1), m_APInt(C))) return nullptr; - BinaryOperator *BO; - if (match(Cmp.getOperand(0), m_BinOp(BO))) { + if (auto *BO = dyn_cast<BinaryOperator>(Cmp.getOperand(0))) { switch (BO->getOpcode()) { case Instruction::Xor: - if (Instruction *I = foldICmpXorConstant(Cmp, BO, C)) + if (Instruction *I = foldICmpXorConstant(Cmp, BO, *C)) return I; break; case Instruction::And: - if (Instruction *I = foldICmpAndConstant(Cmp, BO, C)) + if (Instruction *I = foldICmpAndConstant(Cmp, BO, *C)) return I; break; case Instruction::Or: - if (Instruction *I = foldICmpOrConstant(Cmp, BO, C)) + if (Instruction *I = foldICmpOrConstant(Cmp, BO, *C)) return I; break; case Instruction::Mul: - if (Instruction *I = foldICmpMulConstant(Cmp, BO, C)) + if (Instruction *I = foldICmpMulConstant(Cmp, BO, *C)) return I; break; case Instruction::Shl: - if (Instruction *I = foldICmpShlConstant(Cmp, BO, C)) + if (Instruction *I = foldICmpShlConstant(Cmp, BO, *C)) return I; break; case Instruction::LShr: case Instruction::AShr: - if (Instruction *I = foldICmpShrConstant(Cmp, BO, C)) + if (Instruction *I = foldICmpShrConstant(Cmp, BO, *C)) return I; break; case Instruction::UDiv: - if (Instruction *I = foldICmpUDivConstant(Cmp, BO, C)) + if (Instruction *I = foldICmpUDivConstant(Cmp, BO, *C)) return I; LLVM_FALLTHROUGH; case Instruction::SDiv: - if (Instruction *I = foldICmpDivConstant(Cmp, BO, C)) + if (Instruction 
*I = foldICmpDivConstant(Cmp, BO, *C)) return I; break; case Instruction::Sub: - if (Instruction *I = foldICmpSubConstant(Cmp, BO, C)) + if (Instruction *I = foldICmpSubConstant(Cmp, BO, *C)) return I; break; case Instruction::Add: - if (Instruction *I = foldICmpAddConstant(Cmp, BO, C)) + if (Instruction *I = foldICmpAddConstant(Cmp, BO, *C)) return I; break; default: break; } // TODO: These folds could be refactored to be part of the above calls. - if (Instruction *I = foldICmpBinOpEqualityWithConstant(Cmp, BO, C)) + if (Instruction *I = foldICmpBinOpEqualityWithConstant(Cmp, BO, *C)) return I; } // Match against CmpInst LHS being instructions other than binary operators. - Instruction *LHSI; - if (match(Cmp.getOperand(0), m_Instruction(LHSI))) { - switch (LHSI->getOpcode()) { - case Instruction::Select: - { - // For now, we only support constant integers while folding the - // ICMP(SELECT)) pattern. We can extend this to support vector of integers - // similar to the cases handled by binary ops above. - if (ConstantInt *ConstRHS = dyn_cast<ConstantInt>(Cmp.getOperand(1))) - if (Instruction *I = foldICmpSelectConstant(Cmp, LHSI, ConstRHS)) - return I; - break; - } - case Instruction::Trunc: - if (Instruction *I = foldICmpTruncConstant(Cmp, LHSI, C)) + + if (auto *SI = dyn_cast<SelectInst>(Cmp.getOperand(0))) { + // For now, we only support constant integers while folding the + // ICMP(SELECT)) pattern. We can extend this to support vector of integers + // similar to the cases handled by binary ops above. + if (ConstantInt *ConstRHS = dyn_cast<ConstantInt>(Cmp.getOperand(1))) + if (Instruction *I = foldICmpSelectConstant(Cmp, SI, ConstRHS)) return I; - break; - default: - break; - } } - if (Instruction *I = foldICmpIntrinsicWithConstant(Cmp, C)) + if (auto *TI = dyn_cast<TruncInst>(Cmp.getOperand(0))) { + if (Instruction *I = foldICmpTruncConstant(Cmp, TI, *C)) + return I; + } + + if (Instruction *I = foldICmpIntrinsicWithConstant(Cmp, *C)) return I; return nullptr; @@ -2610,7 +2547,7 @@ Instruction *InstCombiner::foldICmpInstWithConstant(ICmpInst &Cmp) { /// icmp eq/ne BO, C. Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, BinaryOperator *BO, - const APInt *C) { + const APInt &C) { // TODO: Some of these folds could work with arbitrary constants, but this // function is limited to scalar and vector splat constants. if (!Cmp.isEquality()) @@ -2624,7 +2561,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, switch (BO->getOpcode()) { case Instruction::SRem: // If we have a signed (X % (2^c)) == 0, turn it into an unsigned one. - if (C->isNullValue() && BO->hasOneUse()) { + if (C.isNullValue() && BO->hasOneUse()) { const APInt *BOC; if (match(BOp1, m_APInt(BOC)) && BOC->sgt(1) && BOC->isPowerOf2()) { Value *NewRem = Builder.CreateURem(BOp0, BOp1, BO->getName()); @@ -2641,7 +2578,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, Constant *SubC = ConstantExpr::getSub(RHS, cast<Constant>(BOp1)); return new ICmpInst(Pred, BOp0, SubC); } - } else if (C->isNullValue()) { + } else if (C.isNullValue()) { // Replace ((add A, B) != 0) with (A != -B) if A or B is // efficiently invertible, or if the add has just this one use. if (Value *NegVal = dyn_castNegVal(BOp1)) @@ -2662,7 +2599,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, // For the xor case, we can xor two constants together, eliminating // the explicit xor. 
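// A standalone check (plain C++ integers, not InstCombine code) of the SRem fold above:
// for a power-of-two divisor, "(X srem 2^c) == 0" and "(X urem 2^c) == 0" agree, since
// both merely test whether the low c bits of X are zero, so the signed remainder can be
// rewritten as the unsigned form when only equality with zero is observed.
#include <cassert>
#include <cstdint>
int main() {
  for (int X = -128; X <= 127; ++X)
    assert((X % 8 == 0) == (static_cast<uint32_t>(X) % 8u == 0));
  return 0;
}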
return new ICmpInst(Pred, BOp0, ConstantExpr::getXor(RHS, BOC)); - } else if (C->isNullValue()) { + } else if (C.isNullValue()) { // Replace ((xor A, B) != 0) with (A != B) return new ICmpInst(Pred, BOp0, BOp1); } @@ -2675,7 +2612,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, // Replace ((sub BOC, B) != C) with (B != BOC-C). Constant *SubC = ConstantExpr::getSub(cast<Constant>(BOp0), RHS); return new ICmpInst(Pred, BOp1, SubC); - } else if (C->isNullValue()) { + } else if (C.isNullValue()) { // Replace ((sub A, B) != 0) with (A != B). return new ICmpInst(Pred, BOp0, BOp1); } @@ -2697,7 +2634,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, const APInt *BOC; if (match(BOp1, m_APInt(BOC))) { // If we have ((X & C) == C), turn it into ((X & C) != 0). - if (C == BOC && C->isPowerOf2()) + if (C == *BOC && C.isPowerOf2()) return new ICmpInst(isICMP_NE ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, BO, Constant::getNullValue(RHS->getType())); @@ -2713,7 +2650,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, } // ((X & ~7) == 0) --> X < 8 - if (C->isNullValue() && (~(*BOC) + 1).isPowerOf2()) { + if (C.isNullValue() && (~(*BOC) + 1).isPowerOf2()) { Constant *NegBOC = ConstantExpr::getNeg(cast<Constant>(BOp1)); auto NewPred = isICMP_NE ? ICmpInst::ICMP_UGE : ICmpInst::ICMP_ULT; return new ICmpInst(NewPred, BOp0, NegBOC); @@ -2722,7 +2659,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, break; } case Instruction::Mul: - if (C->isNullValue() && BO->hasNoSignedWrap()) { + if (C.isNullValue() && BO->hasNoSignedWrap()) { const APInt *BOC; if (match(BOp1, m_APInt(BOC)) && !BOC->isNullValue()) { // The trivial case (mul X, 0) is handled by InstSimplify. @@ -2733,7 +2670,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, } break; case Instruction::UDiv: - if (C->isNullValue()) { + if (C.isNullValue()) { // (icmp eq/ne (udiv A, B), 0) -> (icmp ugt/ule i32 B, A) auto NewPred = isICMP_NE ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_UGT; return new ICmpInst(NewPred, BOp1, BOp0); @@ -2747,7 +2684,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, /// Fold an icmp with LLVM intrinsic and constant operand: icmp Pred II, C. 
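// A standalone check (plain C++ integers, not InstCombine code) of the Instruction::And
// fold shown above: when C is a power of 2, (X & C) can only be 0 or C, so
// "(X & C) == C" is the same test as "(X & C) != 0". Demo constant: C = 4.
#include <cassert>
#include <cstdint>
int main() {
  const uint32_t C = 4;
  for (uint32_t X = 0; X < 256; ++X)
    assert(((X & C) == C) == ((X & C) != 0));
  return 0;
}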
Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &Cmp, - const APInt *C) { + const APInt &C) { IntrinsicInst *II = dyn_cast<IntrinsicInst>(Cmp.getOperand(0)); if (!II || !Cmp.isEquality()) return nullptr; @@ -2758,13 +2695,13 @@ Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &Cmp, case Intrinsic::bswap: Worklist.Add(II); Cmp.setOperand(0, II->getArgOperand(0)); - Cmp.setOperand(1, ConstantInt::get(Ty, C->byteSwap())); + Cmp.setOperand(1, ConstantInt::get(Ty, C.byteSwap())); return &Cmp; case Intrinsic::ctlz: case Intrinsic::cttz: // ctz(A) == bitwidth(A) -> A == 0 and likewise for != - if (*C == C->getBitWidth()) { + if (C == C.getBitWidth()) { Worklist.Add(II); Cmp.setOperand(0, II->getArgOperand(0)); Cmp.setOperand(1, ConstantInt::getNullValue(Ty)); @@ -2775,8 +2712,8 @@ Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &Cmp, case Intrinsic::ctpop: { // popcount(A) == 0 -> A == 0 and likewise for != // popcount(A) == bitwidth(A) -> A == -1 and likewise for != - bool IsZero = C->isNullValue(); - if (IsZero || *C == C->getBitWidth()) { + bool IsZero = C.isNullValue(); + if (IsZero || C == C.getBitWidth()) { Worklist.Add(II); Cmp.setOperand(0, II->getArgOperand(0)); auto *NewOp = @@ -3924,31 +3861,29 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal, /// When performing a comparison against a constant, it is possible that not all /// the bits in the LHS are demanded. This helper method computes the mask that /// IS demanded. -static APInt getDemandedBitsLHSMask(ICmpInst &I, unsigned BitWidth, - bool isSignCheck) { - if (isSignCheck) - return APInt::getSignMask(BitWidth); +static APInt getDemandedBitsLHSMask(ICmpInst &I, unsigned BitWidth) { + const APInt *RHS; + if (!match(I.getOperand(1), m_APInt(RHS))) + return APInt::getAllOnesValue(BitWidth); - ConstantInt *CI = dyn_cast<ConstantInt>(I.getOperand(1)); - if (!CI) return APInt::getAllOnesValue(BitWidth); - const APInt &RHS = CI->getValue(); + // If this is a normal comparison, it demands all bits. If it is a sign bit + // comparison, it only demands the sign bit. + bool UnusedBit; + if (isSignBitCheck(I.getPredicate(), *RHS, UnusedBit)) + return APInt::getSignMask(BitWidth); switch (I.getPredicate()) { // For a UGT comparison, we don't care about any bits that // correspond to the trailing ones of the comparand. The value of these // bits doesn't impact the outcome of the comparison, because any value // greater than the RHS must differ in a bit higher than these due to carry. - case ICmpInst::ICMP_UGT: { - unsigned trailingOnes = RHS.countTrailingOnes(); - return APInt::getBitsSetFrom(BitWidth, trailingOnes); - } + case ICmpInst::ICMP_UGT: + return APInt::getBitsSetFrom(BitWidth, RHS->countTrailingOnes()); // Similarly, for a ULT comparison, we don't care about the trailing zeros. // Any value less than the RHS must differ in a higher bit because of carries. - case ICmpInst::ICMP_ULT: { - unsigned trailingZeros = RHS.countTrailingZeros(); - return APInt::getBitsSetFrom(BitWidth, trailingZeros); - } + case ICmpInst::ICMP_ULT: + return APInt::getBitsSetFrom(BitWidth, RHS->countTrailingZeros()); default: return APInt::getAllOnesValue(BitWidth); @@ -4122,20 +4057,11 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) { if (!BitWidth) return nullptr; - // If this is a normal comparison, it demands all bits. If it is a sign bit - // comparison, it only demands the sign bit. 
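// A standalone check (plain C++ unsigned math, not InstCombine code) of the
// getDemandedBitsLHSMask reasoning above for ICMP_UGT: with RHS = 7, which has three
// trailing ones, the low three bits of the LHS cannot affect the comparison, so only
// the bits from position 3 upward are demanded.
#include <cassert>
#include <cstdint>
int main() {
  const uint32_t RHS = 7;
  for (uint32_t X = 0; X < 256; ++X)
    for (uint32_t Low = 0; Low < 8; ++Low)        // rewrite only the undemanded low bits
      assert((X > RHS) == (((X & ~7u) | Low) > RHS));
  return 0;
}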
- bool IsSignBit = false; - const APInt *CmpC; - if (match(Op1, m_APInt(CmpC))) { - bool UnusedBit; - IsSignBit = isSignBitCheck(Pred, *CmpC, UnusedBit); - } - KnownBits Op0Known(BitWidth); KnownBits Op1Known(BitWidth); if (SimplifyDemandedBits(&I, 0, - getDemandedBitsLHSMask(I, BitWidth, IsSignBit), + getDemandedBitsLHSMask(I, BitWidth), Op0Known, 0)) return &I; @@ -4233,20 +4159,22 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) { const APInt *CmpC; if (match(Op1, m_APInt(CmpC))) { // A <u C -> A == C-1 if min(A)+1 == C - if (Op1Max == Op0Min + 1) { - Constant *CMinus1 = ConstantInt::get(Op0->getType(), *CmpC - 1); - return new ICmpInst(ICmpInst::ICMP_EQ, Op0, CMinus1); - } + if (*CmpC == Op0Min + 1) + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, + ConstantInt::get(Op1->getType(), *CmpC - 1)); + // X <u C --> X == 0, if the number of zero bits in the bottom of X + // exceeds the log2 of C. + if (Op0Known.countMinTrailingZeros() >= CmpC->ceilLogBase2()) + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, + Constant::getNullValue(Op1->getType())); } break; } case ICmpInst::ICMP_UGT: { if (Op0Min.ugt(Op1Max)) // A >u B -> true if min(A) > max(B) return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Max.ule(Op1Min)) // A >u B -> false if max(A) <= max(B) return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B) return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); @@ -4256,42 +4184,52 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) { if (*CmpC == Op0Max - 1) return new ICmpInst(ICmpInst::ICMP_EQ, Op0, ConstantInt::get(Op1->getType(), *CmpC + 1)); + // X >u C --> X != 0, if the number of zero bits in the bottom of X + // exceeds the log2 of C. 
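// A standalone check (plain C++ unsigned math, not InstCombine code) of the two
// trailing-zero folds added here: if X is known to have at least ceil(log2(C)) low zero
// bits, then "X <u C" can only be satisfied by X == 0, and "X >u C" is simply X != 0.
// Demo: X restricted to multiples of 8 (three known zero bits), C = 6.
#include <cassert>
#include <cstdint>
int main() {
  const uint32_t C = 6;                           // ceil(log2(6)) == 3, active bits == 3
  for (uint32_t X = 0; X < 256; X += 8) {         // models "low 3 bits known zero"
    assert((X < C) == (X == 0));
    assert((X > C) == (X != 0));
  }
  return 0;
}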
+ if (Op0Known.countMinTrailingZeros() >= CmpC->getActiveBits()) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, + Constant::getNullValue(Op1->getType())); } break; } - case ICmpInst::ICMP_SLT: + case ICmpInst::ICMP_SLT: { if (Op0Max.slt(Op1Min)) // A <s B -> true if max(A) < min(C) return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if (Op0Min.sge(Op1Max)) // A <s B -> false if min(A) >= max(C) return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B) return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); - if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { - if (Op1Max == Op0Min + 1) // A <s C -> A == C-1 if min(A)+1 == C + const APInt *CmpC; + if (match(Op1, m_APInt(CmpC))) { + if (*CmpC == Op0Min + 1) // A <s C -> A == C-1 if min(A)+1 == C return new ICmpInst(ICmpInst::ICMP_EQ, Op0, - Builder.getInt(CI->getValue() - 1)); + ConstantInt::get(Op1->getType(), *CmpC - 1)); } break; - case ICmpInst::ICMP_SGT: + } + case ICmpInst::ICMP_SGT: { if (Op0Min.sgt(Op1Max)) // A >s B -> true if min(A) > max(B) return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if (Op0Max.sle(Op1Min)) // A >s B -> false if max(A) <= min(B) return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B) return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); - if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { - if (Op1Min == Op0Max - 1) // A >s C -> A == C+1 if max(A)-1 == C + const APInt *CmpC; + if (match(Op1, m_APInt(CmpC))) { + if (*CmpC == Op0Max - 1) // A >s C -> A == C+1 if max(A)-1 == C return new ICmpInst(ICmpInst::ICMP_EQ, Op0, - Builder.getInt(CI->getValue() + 1)); + ConstantInt::get(Op1->getType(), *CmpC + 1)); } break; + } case ICmpInst::ICMP_SGE: assert(!isa<ConstantInt>(Op1) && "ICMP_SGE with ConstantInt not folded!"); if (Op0Min.sge(Op1Max)) // A >=s B -> true if min(A) >= max(B) return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if (Op0Max.slt(Op1Min)) // A >=s B -> false if max(A) < min(B) return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + if (Op1Min == Op0Max) // A >=s B -> A == B if max(A) == min(B) + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1); break; case ICmpInst::ICMP_SLE: assert(!isa<ConstantInt>(Op1) && "ICMP_SLE with ConstantInt not folded!"); @@ -4299,6 +4237,8 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) { return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if (Op0Min.sgt(Op1Max)) // A <=s B -> false if min(A) > max(B) return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + if (Op1Max == Op0Min) // A <=s B -> A == B if min(A) == max(B) + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1); break; case ICmpInst::ICMP_UGE: assert(!isa<ConstantInt>(Op1) && "ICMP_UGE with ConstantInt not folded!"); @@ -4306,6 +4246,8 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) { return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if (Op0Max.ult(Op1Min)) // A >=u B -> false if max(A) < min(B) return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + if (Op1Min == Op0Max) // A >=u B -> A == B if max(A) == min(B) + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1); break; case ICmpInst::ICMP_ULE: assert(!isa<ConstantInt>(Op1) && "ICMP_ULE with ConstantInt not folded!"); @@ -4313,6 +4255,8 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) { return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if 
(Op0Min.ugt(Op1Max)) // A <=u B -> false if min(A) > max(B) return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + if (Op1Max == Op0Min) // A <=u B -> A == B if min(A) == max(B) + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1); break; } @@ -4478,7 +4422,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - // comparing -val or val with non-zero is the same as just comparing val + // Comparing -val or val with non-zero is the same as just comparing val // ie, abs(val) != 0 -> val != 0 if (I.getPredicate() == ICmpInst::ICMP_NE && match(Op1, m_Zero())) { Value *Cond, *SelectTrue, *SelectFalse; @@ -4515,11 +4459,19 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { // and CodeGen. And in this case, at least one of the comparison // operands has at least one user besides the compare (the select), // which would often largely negate the benefit of folding anyway. + // + // Do the same for the other patterns recognized by matchSelectPattern. if (I.hasOneUse()) - if (SelectInst *SI = dyn_cast<SelectInst>(*I.user_begin())) - if ((SI->getOperand(1) == Op0 && SI->getOperand(2) == Op1) || - (SI->getOperand(2) == Op0 && SI->getOperand(1) == Op1)) + if (SelectInst *SI = dyn_cast<SelectInst>(I.user_back())) { + Value *A, *B; + SelectPatternResult SPR = matchSelectPattern(SI, A, B); + if (SPR.Flavor != SPF_UNKNOWN) return nullptr; + } + + // Do this after checking for min/max to prevent infinite looping. + if (Instruction *Res = foldICmpWithZero(I)) + return Res; // FIXME: We only do this after checking for min/max to prevent infinite // looping caused by a reverse canonicalization of these patterns for min/max. @@ -4684,11 +4636,11 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { Value *X; ConstantInt *Cst; // icmp X+Cst, X if (match(Op0, m_Add(m_Value(X), m_ConstantInt(Cst))) && Op1 == X) - return foldICmpAddOpConst(I, X, Cst, I.getPredicate()); + return foldICmpAddOpConst(X, Cst, I.getPredicate()); // icmp X, X+Cst if (match(Op1, m_Add(m_Value(X), m_ConstantInt(Cst))) && Op0 == X) - return foldICmpAddOpConst(I, X, Cst, I.getSwappedPredicate()); + return foldICmpAddOpConst(X, Cst, I.getSwappedPredicate()); } return Changed ? &I : nullptr; } @@ -4943,17 +4895,16 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { Changed = true; } + const CmpInst::Predicate Pred = I.getPredicate(); Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - - if (Value *V = - SimplifyFCmpInst(I.getPredicate(), Op0, Op1, I.getFastMathFlags(), - SQ.getWithInstruction(&I))) + if (Value *V = SimplifyFCmpInst(Pred, Op0, Op1, I.getFastMathFlags(), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); // Simplify 'fcmp pred X, X' if (Op0 == Op1) { - switch (I.getPredicate()) { - default: llvm_unreachable("Unknown predicate!"); + switch (Pred) { + default: break; case FCmpInst::FCMP_UNO: // True if unordered: isnan(X) | isnan(Y) case FCmpInst::FCMP_ULT: // True if unordered or less than case FCmpInst::FCMP_UGT: // True if unordered or greater than @@ -4974,6 +4925,19 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { } } + // If we're just checking for a NaN (ORD/UNO) and have a non-NaN operand, + // then canonicalize the operand to 0.0. 
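// A standalone illustration (plain C++ <cmath>, not InstCombine code) of the ORD/UNO
// canonicalization described in the comment above and implemented below: "fcmp ord X, Y"
// asks whether neither operand is NaN, so once Y is known never to be NaN the result
// depends only on X and matches "fcmp ord X, 0.0".
#include <cassert>
#include <cmath>
#include <limits>
static bool ord(double A, double B) { return !std::isnan(A) && !std::isnan(B); }
int main() {
  const double Y = 42.0;                          // stands in for a known-never-NaN operand
  const double Xs[] = {0.0, -1.5, std::numeric_limits<double>::infinity(),
                       std::numeric_limits<double>::quiet_NaN()};
  for (double X : Xs)
    assert(ord(X, Y) == ord(X, 0.0));
  return 0;
}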
+ if (Pred == CmpInst::FCMP_ORD || Pred == CmpInst::FCMP_UNO) { + if (!match(Op0, m_Zero()) && isKnownNeverNaN(Op0)) { + I.setOperand(0, ConstantFP::getNullValue(Op0->getType())); + return &I; + } + if (!match(Op1, m_Zero()) && isKnownNeverNaN(Op1)) { + I.setOperand(1, ConstantFP::getNullValue(Op0->getType())); + return &I; + } + } + // Test if the FCmpInst instruction is used exclusively by a select as // part of a minimum or maximum operation. If so, refrain from doing // any other folding. This helps out other analyses which understand @@ -4982,10 +4946,12 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { // operands has at least one user besides the compare (the select), // which would often largely negate the benefit of folding anyway. if (I.hasOneUse()) - if (SelectInst *SI = dyn_cast<SelectInst>(*I.user_begin())) - if ((SI->getOperand(1) == Op0 && SI->getOperand(2) == Op1) || - (SI->getOperand(2) == Op0 && SI->getOperand(1) == Op1)) + if (SelectInst *SI = dyn_cast<SelectInst>(I.user_back())) { + Value *A, *B; + SelectPatternResult SPR = matchSelectPattern(SI, A, B); + if (SPR.Flavor != SPF_UNKNOWN) return nullptr; + } // Handle fcmp with constant RHS if (Constant *RHSC = dyn_cast<Constant>(Op1)) { @@ -5027,7 +4993,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { ((Fabs.compare(APFloat::getSmallestNormalized(*Sem)) != APFloat::cmpLessThan) || Fabs.isZero())) - return new FCmpInst(I.getPredicate(), LHSExt->getOperand(0), + return new FCmpInst(Pred, LHSExt->getOperand(0), ConstantFP::get(RHSC->getContext(), F)); break; } @@ -5072,7 +5038,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { break; // Various optimization for fabs compared with zero. - switch (I.getPredicate()) { + switch (Pred) { default: break; // fabs(x) < 0 --> false @@ -5093,7 +5059,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { case FCmpInst::FCMP_UEQ: case FCmpInst::FCMP_ONE: case FCmpInst::FCMP_UNE: - return new FCmpInst(I.getPredicate(), CI->getArgOperand(0), RHSC); + return new FCmpInst(Pred, CI->getArgOperand(0), RHSC); } } } @@ -5108,8 +5074,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { if (FPExtInst *LHSExt = dyn_cast<FPExtInst>(Op0)) if (FPExtInst *RHSExt = dyn_cast<FPExtInst>(Op1)) if (LHSExt->getSrcTy() == RHSExt->getSrcTy()) - return new FCmpInst(I.getPredicate(), LHSExt->getOperand(0), - RHSExt->getOperand(0)); + return new FCmpInst(Pred, LHSExt->getOperand(0), RHSExt->getOperand(0)); return Changed ? &I : nullptr; } diff --git a/lib/Transforms/InstCombine/InstCombineInternal.h b/lib/Transforms/InstCombine/InstCombineInternal.h index c38a4981bf1d..f1f66d86cb73 100644 --- a/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/lib/Transforms/InstCombine/InstCombineInternal.h @@ -6,42 +6,59 @@ // License. See LICENSE.TXT for details. // //===----------------------------------------------------------------------===// +// /// \file /// /// This file provides internal interfaces used to implement the InstCombine. 
-/// +// //===----------------------------------------------------------------------===// #ifndef LLVM_LIB_TRANSFORMS_INSTCOMBINE_INSTCOMBINEINTERNAL_H #define LLVM_LIB_TRANSFORMS_INSTCOMBINE_INSTCOMBINEINTERNAL_H +#include "llvm/ADT/ArrayRef.h" #include "llvm/Analysis/AliasAnalysis.h" -#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InstructionSimplify.h" -#include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/TargetFolder.h" #include "llvm/Analysis/ValueTracking.h" -#include "llvm/IR/Dominators.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstVisitor.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Operator.h" -#include "llvm/IR/PatternMatch.h" -#include "llvm/Pass.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/KnownBits.h" +#include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/InstCombine/InstCombineWorklist.h" #include "llvm/Transforms/Utils/Local.h" +#include <cassert> +#include <cstdint> #define DEBUG_TYPE "instcombine" namespace llvm { + +class APInt; +class AssumptionCache; class CallSite; class DataLayout; class DominatorTree; +class GEPOperator; +class GlobalVariable; +class LoopInfo; +class OptimizationRemarkEmitter; class TargetLibraryInfo; -class DbgDeclareInst; -class MemIntrinsic; -class MemSetInst; +class User; /// Assign a complexity or rank value to LLVM Values. This is used to reduce /// the amount of pattern matching needed for compares and commutative @@ -109,6 +126,7 @@ static inline Value *peekThroughBitcast(Value *V, bool OneUseOnly = false) { static inline Constant *AddOne(Constant *C) { return ConstantExpr::getAdd(C, ConstantInt::get(C->getType(), 1)); } + /// \brief Subtract one from a Constant static inline Constant *SubOne(Constant *C) { return ConstantExpr::getSub(C, ConstantInt::get(C->getType(), 1)); @@ -118,7 +136,6 @@ static inline Constant *SubOne(Constant *C) { /// This happens in cases where the ~ can be eliminated. If WillInvertAllUses /// is true, work under the assumption that the caller intends to remove all /// uses of V and only keep uses of ~V. -/// static inline bool IsFreeToInvert(Value *V, bool WillInvertAllUses) { // ~(~(X)) -> X. if (BinaryOperator::isNot(V)) @@ -161,7 +178,6 @@ static inline bool IsFreeToInvert(Value *V, bool WillInvertAllUses) { return false; } - /// \brief Specific patterns of overflow check idioms that we match. enum OverflowCheckFlavor { OCF_UNSIGNED_ADD, @@ -209,12 +225,13 @@ public: /// \brief An IRBuilder that automatically inserts new instructions into the /// worklist. - typedef IRBuilder<TargetFolder, IRBuilderCallbackInserter> BuilderTy; + using BuilderTy = IRBuilder<TargetFolder, IRBuilderCallbackInserter>; BuilderTy &Builder; private: // Mode in which we are running the combiner. const bool MinimizeSize; + /// Enable combines that trigger rarely but are costly in compiletime. const bool ExpensiveCombines; @@ -226,20 +243,23 @@ private: DominatorTree &DT; const DataLayout &DL; const SimplifyQuery SQ; + OptimizationRemarkEmitter &ORE; + // Optional analyses. When non-null, these can both be used to do better // combining and will be updated to reflect any changes. 
LoopInfo *LI; - bool MadeIRChange; + bool MadeIRChange = false; public: InstCombiner(InstCombineWorklist &Worklist, BuilderTy &Builder, bool MinimizeSize, bool ExpensiveCombines, AliasAnalysis *AA, AssumptionCache &AC, TargetLibraryInfo &TLI, DominatorTree &DT, - const DataLayout &DL, LoopInfo *LI) + OptimizationRemarkEmitter &ORE, const DataLayout &DL, + LoopInfo *LI) : Worklist(Worklist), Builder(Builder), MinimizeSize(MinimizeSize), ExpensiveCombines(ExpensiveCombines), AA(AA), AC(AC), TLI(TLI), DT(DT), - DL(DL), SQ(DL, &TLI, &DT, &AC), LI(LI), MadeIRChange(false) {} + DL(DL), SQ(DL, &TLI, &DT, &AC), ORE(ORE), LI(LI) {} /// \brief Run the combiner over the entire worklist until it is empty. /// @@ -275,7 +295,7 @@ public: Instruction *visitURem(BinaryOperator &I); Instruction *visitSRem(BinaryOperator &I); Instruction *visitFRem(BinaryOperator &I); - bool SimplifyDivRemOfSelect(BinaryOperator &I); + bool simplifyDivRemOfSelectWithZeroOp(BinaryOperator &I); Instruction *commonRemTransforms(BinaryOperator &I); Instruction *commonIRemTransforms(BinaryOperator &I); Instruction *commonDivTransforms(BinaryOperator &I); @@ -411,32 +431,38 @@ private: bool DoTransform = true); Instruction *transformSExtICmp(ICmpInst *ICI, Instruction &CI); + bool willNotOverflowSignedAdd(const Value *LHS, const Value *RHS, const Instruction &CxtI) const { return computeOverflowForSignedAdd(LHS, RHS, &CxtI) == OverflowResult::NeverOverflows; - }; + } + bool willNotOverflowUnsignedAdd(const Value *LHS, const Value *RHS, const Instruction &CxtI) const { return computeOverflowForUnsignedAdd(LHS, RHS, &CxtI) == OverflowResult::NeverOverflows; - }; + } + bool willNotOverflowSignedSub(const Value *LHS, const Value *RHS, const Instruction &CxtI) const; bool willNotOverflowUnsignedSub(const Value *LHS, const Value *RHS, const Instruction &CxtI) const; bool willNotOverflowSignedMul(const Value *LHS, const Value *RHS, const Instruction &CxtI) const; + bool willNotOverflowUnsignedMul(const Value *LHS, const Value *RHS, const Instruction &CxtI) const { return computeOverflowForUnsignedMul(LHS, RHS, &CxtI) == OverflowResult::NeverOverflows; - }; + } + Value *EmitGEPOffset(User *GEP); Instruction *scalarizePHI(ExtractElementInst &EI, PHINode *PN); Value *EvaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask); Instruction *foldCastedBitwiseLogic(BinaryOperator &I); - Instruction *shrinkBitwiseLogic(TruncInst &Trunc); + Instruction *narrowBinOp(TruncInst &Trunc); + Instruction *narrowRotate(TruncInst &Trunc); Instruction *optimizeBitCastFromPhi(CastInst &CI, PHINode *PN); /// Determine if a pair of casts can be replaced by a single cast. @@ -453,11 +479,14 @@ private: const CastInst *CI2); Value *foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction &CxtI); - Value *foldAndOfFCmps(FCmpInst *LHS, FCmpInst *RHS); Value *foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction &CxtI); - Value *foldOrOfFCmps(FCmpInst *LHS, FCmpInst *RHS); Value *foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS); + /// Optimize (fcmp)&(fcmp) or (fcmp)|(fcmp). + /// NOTE: Unlike most of instcombine, this returns a Value which should + /// already be inserted into the function. 
+ Value *foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, bool IsAnd); + Value *foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS, bool JoinedByAnd, Instruction &CxtI); public: @@ -542,6 +571,7 @@ public: unsigned Depth, const Instruction *CxtI) const { llvm::computeKnownBits(V, Known, DL, Depth, &AC, CxtI, &DT); } + KnownBits computeKnownBits(const Value *V, unsigned Depth, const Instruction *CxtI) const { return llvm::computeKnownBits(V, DL, Depth, &AC, CxtI, &DT); @@ -557,20 +587,24 @@ public: const Instruction *CxtI = nullptr) const { return llvm::MaskedValueIsZero(V, Mask, DL, Depth, &AC, CxtI, &DT); } + unsigned ComputeNumSignBits(const Value *Op, unsigned Depth = 0, const Instruction *CxtI = nullptr) const { return llvm::ComputeNumSignBits(Op, DL, Depth, &AC, CxtI, &DT); } + OverflowResult computeOverflowForUnsignedMul(const Value *LHS, const Value *RHS, const Instruction *CxtI) const { return llvm::computeOverflowForUnsignedMul(LHS, RHS, DL, &AC, CxtI, &DT); } + OverflowResult computeOverflowForUnsignedAdd(const Value *LHS, const Value *RHS, const Instruction *CxtI) const { return llvm::computeOverflowForUnsignedAdd(LHS, RHS, DL, &AC, CxtI, &DT); } + OverflowResult computeOverflowForSignedAdd(const Value *LHS, const Value *RHS, const Instruction *CxtI) const { @@ -594,6 +628,11 @@ private: /// value, or null if it didn't simplify. Value *SimplifyUsingDistributiveLaws(BinaryOperator &I); + // Binary Op helper for select operations where the expression can be + // efficiently reorganized. + Value *SimplifySelectsFeedingBinaryOp(BinaryOperator &I, Value *LHS, + Value *RHS); + /// This tries to simplify binary operations by factorizing out common terms /// (e. g. "(A*B)+(A*C)" -> "A*(B+C)"). Value *tryFactorization(BinaryOperator &, Instruction::BinaryOps, Value *, @@ -615,6 +654,7 @@ private: bool SimplifyDemandedBits(Instruction *I, unsigned Op, const APInt &DemandedMask, KnownBits &Known, unsigned Depth = 0); + /// Helper routine of SimplifyDemandedUseBits. It computes KnownZero/KnownOne /// bits. It also tries to handle simplifications that can be done based on /// DemandedMask, but without modifying the Instruction. @@ -622,6 +662,7 @@ private: const APInt &DemandedMask, KnownBits &Known, unsigned Depth, Instruction *CxtI); + /// Helper routine of SimplifyDemandedUseBits. It tries to simplify demanded /// bit for "r1 = shr x, c1; r2 = shl r1, c2" instruction sequence. Value *simplifyShrShlDemandedBits( @@ -652,6 +693,8 @@ private: /// This is a convenience wrapper function for the above two functions. Instruction *foldOpWithConstantIntoOperand(BinaryOperator &I); + Instruction *foldAddWithConstant(BinaryOperator &Add); + /// \brief Try to rotate an operation below a PHI node, using PHI nodes for /// its operands. Instruction *FoldPHIArgOpIntoPHI(PHINode &PN); @@ -660,9 +703,14 @@ private: Instruction *FoldPHIArgLoadIntoPHI(PHINode &PN); Instruction *FoldPHIArgZextsIntoPHI(PHINode &PN); - /// Helper function for FoldPHIArgXIntoPHI() to get debug location for the + /// If an integer typed PHI has only one use which is an IntToPtr operation, + /// replace the PHI with an existing pointer typed PHI if it exists. Otherwise + /// insert a new pointer typed PHI and replace the original one. + Instruction *FoldIntegerTypedPHI(PHINode &PN); + + /// Helper function for FoldPHIArgXIntoPHI() to set debug location for the /// folded operation. 
- DebugLoc PHIArgMergedDebugLoc(PHINode &PN); + void PHIArgMergedDebugLoc(Instruction *Inst, PHINode &PN); Instruction *foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, ICmpInst::Predicate Cond, Instruction &I); @@ -673,7 +721,7 @@ private: ConstantInt *AndCst = nullptr); Instruction *foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI, Constant *RHSC); - Instruction *foldICmpAddOpConst(Instruction &ICI, Value *X, ConstantInt *CI, + Instruction *foldICmpAddOpConst(Value *X, ConstantInt *CI, ICmpInst::Predicate Pred); Instruction *foldICmpWithCastAndCast(ICmpInst &ICI); @@ -683,35 +731,36 @@ private: Instruction *foldICmpInstWithConstantNotInt(ICmpInst &Cmp); Instruction *foldICmpBinOp(ICmpInst &Cmp); Instruction *foldICmpEquality(ICmpInst &Cmp); + Instruction *foldICmpWithZero(ICmpInst &Cmp); - Instruction *foldICmpSelectConstant(ICmpInst &Cmp, Instruction *Select, + Instruction *foldICmpSelectConstant(ICmpInst &Cmp, SelectInst *Select, ConstantInt *C); - Instruction *foldICmpTruncConstant(ICmpInst &Cmp, Instruction *Trunc, - const APInt *C); + Instruction *foldICmpTruncConstant(ICmpInst &Cmp, TruncInst *Trunc, + const APInt &C); Instruction *foldICmpAndConstant(ICmpInst &Cmp, BinaryOperator *And, - const APInt *C); + const APInt &C); Instruction *foldICmpXorConstant(ICmpInst &Cmp, BinaryOperator *Xor, - const APInt *C); + const APInt &C); Instruction *foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or, - const APInt *C); + const APInt &C); Instruction *foldICmpMulConstant(ICmpInst &Cmp, BinaryOperator *Mul, - const APInt *C); + const APInt &C); Instruction *foldICmpShlConstant(ICmpInst &Cmp, BinaryOperator *Shl, - const APInt *C); + const APInt &C); Instruction *foldICmpShrConstant(ICmpInst &Cmp, BinaryOperator *Shr, - const APInt *C); + const APInt &C); Instruction *foldICmpUDivConstant(ICmpInst &Cmp, BinaryOperator *UDiv, - const APInt *C); + const APInt &C); Instruction *foldICmpDivConstant(ICmpInst &Cmp, BinaryOperator *Div, - const APInt *C); + const APInt &C); Instruction *foldICmpSubConstant(ICmpInst &Cmp, BinaryOperator *Sub, - const APInt *C); + const APInt &C); Instruction *foldICmpAddConstant(ICmpInst &Cmp, BinaryOperator *Add, - const APInt *C); + const APInt &C); Instruction *foldICmpAndConstConst(ICmpInst &Cmp, BinaryOperator *And, - const APInt *C1); + const APInt &C1); Instruction *foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And, - const APInt *C1, const APInt *C2); + const APInt &C1, const APInt &C2); Instruction *foldICmpShrConstConst(ICmpInst &I, Value *ShAmt, const APInt &C1, const APInt &C2); Instruction *foldICmpShlConstConst(ICmpInst &I, Value *ShAmt, const APInt &C1, @@ -719,8 +768,8 @@ private: Instruction *foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, BinaryOperator *BO, - const APInt *C); - Instruction *foldICmpIntrinsicWithConstant(ICmpInst &ICI, const APInt *C); + const APInt &C); + Instruction *foldICmpIntrinsicWithConstant(ICmpInst &ICI, const APInt &C); // Helpers of visitSelectInst(). Instruction *foldSelectExtConst(SelectInst &Sel); @@ -740,8 +789,7 @@ private: Instruction *MatchBSwap(BinaryOperator &I); bool SimplifyStoreAtEndOfBlock(StoreInst &SI); - Instruction * - SimplifyElementUnorderedAtomicMemCpy(ElementUnorderedAtomicMemCpyInst *AMI); + Instruction *SimplifyElementUnorderedAtomicMemCpy(AtomicMemCpyInst *AMI); Instruction *SimplifyMemTransfer(MemIntrinsic *MI); Instruction *SimplifyMemSet(MemSetInst *MI); @@ -753,8 +801,8 @@ private: Value *Descale(Value *Val, APInt Scale, bool &NoSignedWrap); }; -} // end namespace llvm. 
+} // end namespace llvm #undef DEBUG_TYPE -#endif +#endif // LLVM_LIB_TRANSFORMS_INSTCOMBINE_INSTCOMBINEINTERNAL_H diff --git a/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp index 451036545741..d4f06e18b957 100644 --- a/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ b/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -18,13 +18,14 @@ #include "llvm/Analysis/Loads.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/DataLayout.h" -#include "llvm/IR/DebugInfo.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" using namespace llvm; +using namespace PatternMatch; #define DEBUG_TYPE "instcombine" @@ -561,6 +562,28 @@ static StoreInst *combineStoreToNewValue(InstCombiner &IC, StoreInst &SI, Value return NewStore; } +/// Returns true if instruction represent minmax pattern like: +/// select ((cmp load V1, load V2), V1, V2). +static bool isMinMaxWithLoads(Value *V) { + assert(V->getType()->isPointerTy() && "Expected pointer type."); + // Ignore possible ty* to ixx* bitcast. + V = peekThroughBitcast(V); + // Check that select is select ((cmp load V1, load V2), V1, V2) - minmax + // pattern. + CmpInst::Predicate Pred; + Instruction *L1; + Instruction *L2; + Value *LHS; + Value *RHS; + if (!match(V, m_Select(m_Cmp(Pred, m_Instruction(L1), m_Instruction(L2)), + m_Value(LHS), m_Value(RHS)))) + return false; + return (match(L1, m_Load(m_Specific(LHS))) && + match(L2, m_Load(m_Specific(RHS)))) || + (match(L1, m_Load(m_Specific(RHS))) && + match(L2, m_Load(m_Specific(LHS)))); +} + /// \brief Combine loads to match the type of their uses' value after looking /// through intervening bitcasts. /// @@ -598,10 +621,14 @@ static Instruction *combineLoadToOperationType(InstCombiner &IC, LoadInst &LI) { // integers instead of any other type. We only do this when the loaded type // is sized and has a size exactly the same as its store size and the store // size is a legal integer type. + // Do not perform canonicalization if minmax pattern is found (to avoid + // infinite loop). if (!Ty->isIntegerTy() && Ty->isSized() && DL.isLegalInteger(DL.getTypeStoreSizeInBits(Ty)) && DL.getTypeStoreSizeInBits(Ty) == DL.getTypeSizeInBits(Ty) && - !DL.isNonIntegralPointerType(Ty)) { + !DL.isNonIntegralPointerType(Ty) && + !isMinMaxWithLoads( + peekThroughBitcast(LI.getPointerOperand(), /*OneUseOnly=*/true))) { if (all_of(LI.users(), [&LI](User *U) { auto *SI = dyn_cast<StoreInst>(U); return SI && SI->getPointerOperand() != &LI && @@ -931,6 +958,16 @@ static Instruction *replaceGEPIdxWithZero(InstCombiner &IC, Value *Ptr, return nullptr; } +static bool canSimplifyNullStoreOrGEP(StoreInst &SI) { + if (SI.getPointerAddressSpace() != 0) + return false; + + auto *Ptr = SI.getPointerOperand(); + if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(Ptr)) + Ptr = GEPI->getOperand(0); + return isa<ConstantPointerNull>(Ptr); +} + static bool canSimplifyNullLoadOrGEP(LoadInst &LI, Value *Op) { if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(Op)) { const Value *GEPI0 = GEPI->getOperand(0); @@ -1298,6 +1335,46 @@ static bool equivalentAddressValues(Value *A, Value *B) { return false; } +/// Converts store (bitcast (load (bitcast (select ...)))) to +/// store (load (select ...)), where select is minmax: +/// select ((cmp load V1, load V2), V1, V2). 
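// A hedged, invented illustration (not taken from the patch) of the source-level shape
// that isMinMaxWithLoads() above is guarding: a pointer select whose condition compares
// the two loaded values, i.e. "select (icmp (load p), (load q)), p, q". The function
// name and types here are made up purely for this sketch.
#include <cassert>
static const int *minPtr(const int *P, const int *Q) {
  return *P < *Q ? P : Q;       // select (icmp slt (load P), (load Q)), P, Q
}
int main() {
  int A = 3, B = 7;
  assert(*minPtr(&A, &B) == 3); // loading through the selected pointer yields the minimum
  return 0;
}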
+static bool removeBitcastsFromLoadStoreOnMinMax(InstCombiner &IC, + StoreInst &SI) { + // bitcast? + if (!match(SI.getPointerOperand(), m_BitCast(m_Value()))) + return false; + // load? integer? + Value *LoadAddr; + if (!match(SI.getValueOperand(), m_Load(m_BitCast(m_Value(LoadAddr))))) + return false; + auto *LI = cast<LoadInst>(SI.getValueOperand()); + if (!LI->getType()->isIntegerTy()) + return false; + if (!isMinMaxWithLoads(LoadAddr)) + return false; + + if (!all_of(LI->users(), [LI, LoadAddr](User *U) { + auto *SI = dyn_cast<StoreInst>(U); + return SI && SI->getPointerOperand() != LI && + peekThroughBitcast(SI->getPointerOperand()) != LoadAddr && + !SI->getPointerOperand()->isSwiftError(); + })) + return false; + + IC.Builder.SetInsertPoint(LI); + LoadInst *NewLI = combineLoadToNewType( + IC, *LI, LoadAddr->getType()->getPointerElementType()); + // Replace all the stores with stores of the newly loaded value. + for (auto *UI : LI->users()) { + auto *USI = cast<StoreInst>(UI); + IC.Builder.SetInsertPoint(USI); + combineStoreToNewValue(IC, *USI, NewLI); + } + IC.replaceInstUsesWith(*LI, UndefValue::get(LI->getType())); + IC.eraseInstFromFunction(*LI); + return true; +} + Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { Value *Val = SI.getOperand(0); Value *Ptr = SI.getOperand(1); @@ -1322,6 +1399,9 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { if (unpackStoreToAggregate(*this, SI)) return eraseInstFromFunction(SI); + if (removeBitcastsFromLoadStoreOnMinMax(*this, SI)) + return eraseInstFromFunction(SI); + // Replace GEP indices if possible. if (Instruction *NewGEPI = replaceGEPIdxWithZero(*this, Ptr, SI)) { Worklist.Add(NewGEPI); @@ -1392,7 +1472,8 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { } // store X, null -> turns into 'unreachable' in SimplifyCFG - if (isa<ConstantPointerNull>(Ptr) && SI.getPointerAddressSpace() == 0) { + // store X, GEP(null, Y) -> turns into 'unreachable' in SimplifyCFG + if (canSimplifyNullStoreOrGEP(SI)) { if (!isa<UndefValue>(Val)) { SI.setOperand(0, UndefValue::get(Val->getType())); if (Instruction *U = dyn_cast<Instruction>(Val)) @@ -1544,8 +1625,7 @@ bool InstCombiner::SimplifyStoreAtEndOfBlock(StoreInst &SI) { SI.getSyncScopeID()); InsertNewInstBefore(NewSI, *BBI); // The debug locations of the original instructions might differ; merge them. - NewSI->setDebugLoc(DILocation::getMergedLocation(SI.getDebugLoc(), - OtherStore->getDebugLoc())); + NewSI->applyMergedLocation(SI.getDebugLoc(), OtherStore->getDebugLoc()); // If the two stores had AA tags, merge them. 
AAMDNodes AATags; diff --git a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index e3a50220f94e..87666360c1a0 100644 --- a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -13,15 +13,36 @@ //===----------------------------------------------------------------------===// #include "InstCombineInternal.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/KnownBits.h" +#include "llvm/Transforms/InstCombine/InstCombineWorklist.h" +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <utility> + using namespace llvm; using namespace PatternMatch; #define DEBUG_TYPE "instcombine" - /// The specific integer value is used in a context where it is known to be /// non-zero. If this allows us to simplify the computation, do so and return /// the new operand, otherwise return null. @@ -73,7 +94,6 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC, return MadeChange ? V : nullptr; } - /// True if the multiply can not be expressed in an int this size. static bool MultiplyOverflows(const APInt &C1, const APInt &C2, APInt &Product, bool IsSigned) { @@ -467,7 +487,7 @@ static void detectLog2OfHalf(Value *&Op, Value *&Y, IntrinsicInst *&Log2) { IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op); if (!II) return; - if (II->getIntrinsicID() != Intrinsic::log2 || !II->hasUnsafeAlgebra()) + if (II->getIntrinsicID() != Intrinsic::log2 || !II->isFast()) return; Log2 = II; @@ -478,7 +498,8 @@ static void detectLog2OfHalf(Value *&Op, Value *&Y, IntrinsicInst *&Log2) { Instruction *I = dyn_cast<Instruction>(OpLog2Of); if (!I) return; - if (I->getOpcode() != Instruction::FMul || !I->hasUnsafeAlgebra()) + + if (I->getOpcode() != Instruction::FMul || !I->isFast()) return; if (match(I->getOperand(0), m_SpecificFP(0.5))) @@ -540,7 +561,6 @@ static bool isFMulOrFDivWithConstant(Value *V) { /// This function is to simplify "FMulOrDiv * C" and returns the /// resulting expression. Note that this function could return NULL in /// case the constants cannot be folded into a normal floating-point. -/// Value *InstCombiner::foldFMulConst(Instruction *FMulOrDiv, Constant *C, Instruction *InsertBefore) { assert(isFMulOrFDivWithConstant(FMulOrDiv) && "V is invalid"); @@ -582,7 +602,7 @@ Value *InstCombiner::foldFMulConst(Instruction *FMulOrDiv, Constant *C, } if (R) { - R->setHasUnsafeAlgebra(true); + R->setFast(true); InsertNewInstWith(R, *InsertBefore); } @@ -603,7 +623,7 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - bool AllowReassociate = I.hasUnsafeAlgebra(); + bool AllowReassociate = I.isFast(); // Simplify mul instructions with a constant RHS. 
if (isa<Constant>(Op1)) { @@ -736,6 +756,10 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { } } + // Handle specials cases for FMul with selects feeding the operation + if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1)) + return replaceInstUsesWith(I, V); + // (X*Y) * X => (X*X) * Y where Y != X // The purpose is two-fold: // 1) to form a power expression (of X). @@ -743,7 +767,6 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { // latency of the instruction Y is amortized by the expression of X*X, // and therefore Y is in a "less critical" position compared to what it // was before the transformation. - // if (AllowReassociate) { Value *Opnd0_0, *Opnd0_1; if (Opnd0->hasOneUse() && @@ -774,24 +797,23 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { return Changed ? &I : nullptr; } -/// Try to fold a divide or remainder of a select instruction. -bool InstCombiner::SimplifyDivRemOfSelect(BinaryOperator &I) { - SelectInst *SI = cast<SelectInst>(I.getOperand(1)); - - // div/rem X, (Cond ? 0 : Y) -> div/rem X, Y - int NonNullOperand = -1; - if (Constant *ST = dyn_cast<Constant>(SI->getOperand(1))) - if (ST->isNullValue()) - NonNullOperand = 2; - // div/rem X, (Cond ? Y : 0) -> div/rem X, Y - if (Constant *ST = dyn_cast<Constant>(SI->getOperand(2))) - if (ST->isNullValue()) - NonNullOperand = 1; - - if (NonNullOperand == -1) +/// Fold a divide or remainder with a select instruction divisor when one of the +/// select operands is zero. In that case, we can use the other select operand +/// because div/rem by zero is undefined. +bool InstCombiner::simplifyDivRemOfSelectWithZeroOp(BinaryOperator &I) { + SelectInst *SI = dyn_cast<SelectInst>(I.getOperand(1)); + if (!SI) return false; - Value *SelectCond = SI->getOperand(0); + int NonNullOperand; + if (match(SI->getTrueValue(), m_Zero())) + // div/rem X, (Cond ? 0 : Y) -> div/rem X, Y + NonNullOperand = 2; + else if (match(SI->getFalseValue(), m_Zero())) + // div/rem X, (Cond ? Y : 0) -> div/rem X, Y + NonNullOperand = 1; + else + return false; // Change the div/rem to use 'Y' instead of the select. I.setOperand(1, SI->getOperand(NonNullOperand)); @@ -804,12 +826,13 @@ bool InstCombiner::SimplifyDivRemOfSelect(BinaryOperator &I) { // If the select and condition only have a single use, don't bother with this, // early exit. + Value *SelectCond = SI->getCondition(); if (SI->use_empty() && SelectCond->hasOneUse()) return true; // Scan the current block backward, looking for other uses of SI. BasicBlock::iterator BBI = I.getIterator(), BBFront = I.getParent()->begin(); - + Type *CondTy = SelectCond->getType(); while (BBI != BBFront) { --BBI; // If we found a call to a function, we can't assume it will return, so @@ -824,7 +847,8 @@ bool InstCombiner::SimplifyDivRemOfSelect(BinaryOperator &I) { *I = SI->getOperand(NonNullOperand); Worklist.Add(&*BBI); } else if (*I == SelectCond) { - *I = Builder.getInt1(NonNullOperand == 1); + *I = NonNullOperand == 1 ? ConstantInt::getTrue(CondTy) + : ConstantInt::getFalse(CondTy); Worklist.Add(&*BBI); } } @@ -843,7 +867,6 @@ bool InstCombiner::SimplifyDivRemOfSelect(BinaryOperator &I) { return true; } - /// This function implements the transforms common to both integer division /// instructions (udiv and sdiv). It is called by the visitors to those integer /// division instructions. @@ -859,7 +882,7 @@ Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) { // Handle cases involving: [su]div X, (select Cond, Y, Z) // This does not apply for fdiv. 
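// A hedged, invented source-level sketch (not from the patch) of why the rewritten
// simplifyDivRemOfSelectWithZeroOp() is sound: if the select chose its zero arm, the
// division would be by zero and therefore undefined, so the optimizer may assume the
// non-zero arm is the operative one and divide by it directly.
#include <cassert>
static int divBySelect(bool Cond, int X, int Y) {
  return X / (Cond ? 0 : Y);    // div X, (select Cond, 0, Y)  ==>  treatable as  div X, Y
}
int main() {
  assert(divBySelect(false, 12, 4) == 3);   // only the well-defined path is exercised here
  return 0;
}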
- if (isa<SelectInst>(Op1) && SimplifyDivRemOfSelect(I)) + if (simplifyDivRemOfSelectWithZeroOp(I)) return &I; if (Instruction *LHS = dyn_cast<Instruction>(Op0)) { @@ -969,38 +992,29 @@ Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) { return nullptr; } -/// dyn_castZExtVal - Checks if V is a zext or constant that can -/// be truncated to Ty without losing bits. -static Value *dyn_castZExtVal(Value *V, Type *Ty) { - if (ZExtInst *Z = dyn_cast<ZExtInst>(V)) { - if (Z->getSrcTy() == Ty) - return Z->getOperand(0); - } else if (ConstantInt *C = dyn_cast<ConstantInt>(V)) { - if (C->getValue().getActiveBits() <= cast<IntegerType>(Ty)->getBitWidth()) - return ConstantExpr::getTrunc(C, Ty); - } - return nullptr; -} +static const unsigned MaxDepth = 6; namespace { -const unsigned MaxDepth = 6; -typedef Instruction *(*FoldUDivOperandCb)(Value *Op0, Value *Op1, - const BinaryOperator &I, - InstCombiner &IC); + +using FoldUDivOperandCb = Instruction *(*)(Value *Op0, Value *Op1, + const BinaryOperator &I, + InstCombiner &IC); /// \brief Used to maintain state for visitUDivOperand(). struct UDivFoldAction { - FoldUDivOperandCb FoldAction; ///< Informs visitUDiv() how to fold this - ///< operand. This can be zero if this action - ///< joins two actions together. + /// Informs visitUDiv() how to fold this operand. This can be zero if this + /// action joins two actions together. + FoldUDivOperandCb FoldAction; + + /// Which operand to fold. + Value *OperandToFold; - Value *OperandToFold; ///< Which operand to fold. union { - Instruction *FoldResult; ///< The instruction returned when FoldAction is - ///< invoked. + /// The instruction returned when FoldAction is invoked. + Instruction *FoldResult; - size_t SelectLHSIdx; ///< Stores the LHS action index if this action - ///< joins two actions together. + /// Stores the LHS action index if this action joins two actions together. + size_t SelectLHSIdx; }; UDivFoldAction(FoldUDivOperandCb FA, Value *InputOperand) @@ -1008,7 +1022,8 @@ struct UDivFoldAction { UDivFoldAction(FoldUDivOperandCb FA, Value *InputOperand, size_t SLHS) : FoldAction(FA), OperandToFold(InputOperand), SelectLHSIdx(SLHS) {} }; -} + +} // end anonymous namespace // X udiv 2^C -> X >> C static Instruction *foldUDivPow2Cst(Value *Op0, Value *Op1, @@ -1095,6 +1110,43 @@ static size_t visitUDivOperand(Value *Op0, Value *Op1, const BinaryOperator &I, return 0; } +/// If we have zero-extended operands of an unsigned div or rem, we may be able +/// to narrow the operation (sink the zext below the math). +static Instruction *narrowUDivURem(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + Instruction::BinaryOps Opcode = I.getOpcode(); + Value *N = I.getOperand(0); + Value *D = I.getOperand(1); + Type *Ty = I.getType(); + Value *X, *Y; + if (match(N, m_ZExt(m_Value(X))) && match(D, m_ZExt(m_Value(Y))) && + X->getType() == Y->getType() && (N->hasOneUse() || D->hasOneUse())) { + // udiv (zext X), (zext Y) --> zext (udiv X, Y) + // urem (zext X), (zext Y) --> zext (urem X, Y) + Value *NarrowOp = Builder.CreateBinOp(Opcode, X, Y); + return new ZExtInst(NarrowOp, Ty); + } + + Constant *C; + if ((match(N, m_OneUse(m_ZExt(m_Value(X)))) && match(D, m_Constant(C))) || + (match(D, m_OneUse(m_ZExt(m_Value(X)))) && match(N, m_Constant(C)))) { + // If the constant is the same in the smaller type, use the narrow version. 
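  // e.g. (sketch): udiv (zext i8 %x to i32), 200 narrows to
  //   zext (udiv i8 %x, 200) to i32
  // because 200 round-trips through i8 unchanged, whereas a divisor such as
  // 300 does not fit in i8 and that case is left alone.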
+ Constant *TruncC = ConstantExpr::getTrunc(C, X->getType()); + if (ConstantExpr::getZExt(TruncC, Ty) != C) + return nullptr; + + // udiv (zext X), C --> zext (udiv X, C') + // urem (zext X), C --> zext (urem X, C') + // udiv C, (zext X) --> zext (udiv C', X) + // urem C, (zext X) --> zext (urem C', X) + Value *NarrowOp = isa<Constant>(D) ? Builder.CreateBinOp(Opcode, X, TruncC) + : Builder.CreateBinOp(Opcode, TruncC, X); + return new ZExtInst(NarrowOp, Ty); + } + + return nullptr; +} + Instruction *InstCombiner::visitUDiv(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); @@ -1127,12 +1179,8 @@ Instruction *InstCombiner::visitUDiv(BinaryOperator &I) { } } - // (zext A) udiv (zext B) --> zext (A udiv B) - if (ZExtInst *ZOp0 = dyn_cast<ZExtInst>(Op0)) - if (Value *ZOp1 = dyn_castZExtVal(Op1, ZOp0->getSrcTy())) - return new ZExtInst( - Builder.CreateUDiv(ZOp0->getOperand(0), ZOp1, "div", I.isExact()), - I.getType()); + if (Instruction *NarrowDiv = narrowUDivURem(I, Builder)) + return NarrowDiv; // (LHS udiv (select (select (...)))) -> (LHS >> (select (select (...)))) SmallVector<UDivFoldAction, 6> UDivActions; @@ -1255,8 +1303,7 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { /// 1) 1/C is exact, or /// 2) reciprocal is allowed. /// If the conversion was successful, the simplified expression "X * 1/C" is -/// returned; otherwise, NULL is returned. -/// +/// returned; otherwise, nullptr is returned. static Instruction *CvtFDivConstToReciprocal(Value *Dividend, Constant *Divisor, bool AllowReciprocal) { if (!isa<ConstantFP>(Divisor)) // TODO: handle vectors. @@ -1295,7 +1342,7 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { if (Instruction *R = FoldOpIntoSelect(I, SI)) return R; - bool AllowReassociate = I.hasUnsafeAlgebra(); + bool AllowReassociate = I.isFast(); bool AllowReciprocal = I.hasAllowReciprocal(); if (Constant *Op1C = dyn_cast<Constant>(Op1)) { @@ -1317,7 +1364,6 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { Res = BinaryOperator::CreateFMul(X, C); } else if (match(Op0, m_FDiv(m_Value(X), m_Constant(C1)))) { // (X/C1)/C2 => X /(C2*C1) [=> X * 1/(C2*C1) if reciprocal is allowed] - // Constant *C = ConstantExpr::getFMul(C1, C2); if (isNormalFp(C)) { Res = CvtFDivConstToReciprocal(X, C, AllowReciprocal); @@ -1375,7 +1421,6 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { if (Op0->hasOneUse() && match(Op0, m_FDiv(m_Value(X), m_Value(Y)))) { // (X/Y) / Z => X / (Y*Z) - // if (!isa<Constant>(Y) || !isa<Constant>(Op1)) { NewInst = Builder.CreateFMul(Y, Op1); if (Instruction *RI = dyn_cast<Instruction>(NewInst)) { @@ -1387,7 +1432,6 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { } } else if (Op1->hasOneUse() && match(Op1, m_FDiv(m_Value(X), m_Value(Y)))) { // Z / (X/Y) => Z*Y / X - // if (!isa<Constant>(Y) || !isa<Constant>(Op0)) { NewInst = Builder.CreateFMul(Op0, Y); if (Instruction *RI = dyn_cast<Instruction>(NewInst)) { @@ -1434,7 +1478,7 @@ Instruction *InstCombiner::commonIRemTransforms(BinaryOperator &I) { } // Handle cases involving: rem X, (select Cond, Y, Z) - if (isa<SelectInst>(Op1) && SimplifyDivRemOfSelect(I)) + if (simplifyDivRemOfSelectWithZeroOp(I)) return &I; if (isa<Constant>(Op1)) { @@ -1443,7 +1487,6 @@ Instruction *InstCombiner::commonIRemTransforms(BinaryOperator &I) { if (Instruction *R = FoldOpIntoSelect(I, SI)) return R; } else if (auto *PN = dyn_cast<PHINode>(Op0I)) { - using namespace llvm::PatternMatch; const APInt *Op1Int; if (match(Op1, m_APInt(Op1Int)) && 
!Op1Int->isMinValue() && (I.getOpcode() == Instruction::URem || @@ -1477,11 +1520,8 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) { if (Instruction *common = commonIRemTransforms(I)) return common; - // (zext A) urem (zext B) --> zext (A urem B) - if (ZExtInst *ZOp0 = dyn_cast<ZExtInst>(Op0)) - if (Value *ZOp1 = dyn_castZExtVal(Op1, ZOp0->getSrcTy())) - return new ZExtInst(Builder.CreateURem(ZOp0->getOperand(0), ZOp1), - I.getType()); + if (Instruction *NarrowRem = narrowUDivURem(I, Builder)) + return NarrowRem; // X urem Y -> X and Y-1, where Y is a power of 2, if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/ true, 0, &I)) { @@ -1592,7 +1632,7 @@ Instruction *InstCombiner::visitFRem(BinaryOperator &I) { return replaceInstUsesWith(I, V); // Handle cases involving: rem X, (select Cond, Y, Z) - if (isa<SelectInst>(Op1) && SimplifyDivRemOfSelect(I)) + if (simplifyDivRemOfSelectWithZeroOp(I)) return &I; return nullptr; diff --git a/lib/Transforms/InstCombine/InstCombinePHI.cpp b/lib/Transforms/InstCombine/InstCombinePHI.cpp index 0011412c2bf4..7ee018dbc49b 100644 --- a/lib/Transforms/InstCombine/InstCombinePHI.cpp +++ b/lib/Transforms/InstCombine/InstCombinePHI.cpp @@ -16,7 +16,6 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/ValueTracking.h" -#include "llvm/IR/DebugInfo.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Transforms/Utils/Local.h" using namespace llvm; @@ -27,16 +26,249 @@ using namespace llvm::PatternMatch; /// The PHI arguments will be folded into a single operation with a PHI node /// as input. The debug location of the single operation will be the merged /// locations of the original PHI node arguments. -DebugLoc InstCombiner::PHIArgMergedDebugLoc(PHINode &PN) { +void InstCombiner::PHIArgMergedDebugLoc(Instruction *Inst, PHINode &PN) { auto *FirstInst = cast<Instruction>(PN.getIncomingValue(0)); - const DILocation *Loc = FirstInst->getDebugLoc(); + Inst->setDebugLoc(FirstInst->getDebugLoc()); + // We do not expect a CallInst here, otherwise, N-way merging of DebugLoc + // will be inefficient. + assert(!isa<CallInst>(Inst)); for (unsigned i = 1; i != PN.getNumIncomingValues(); ++i) { auto *I = cast<Instruction>(PN.getIncomingValue(i)); - Loc = DILocation::getMergedLocation(Loc, I->getDebugLoc()); + Inst->applyMergedLocation(Inst->getDebugLoc(), I->getDebugLoc()); + } +} + +// Replace Integer typed PHI PN if the PHI's value is used as a pointer value. +// If there is an existing pointer typed PHI that produces the same value as PN, +// replace PN and the IntToPtr operation with it. Otherwise, synthesize a new +// PHI node: +// +// Case-1: +// bb1: +// int_init = PtrToInt(ptr_init) +// br label %bb2 +// bb2: +// int_val = PHI([int_init, %bb1], [int_val_inc, %bb2] +// ptr_val = PHI([ptr_init, %bb1], [ptr_val_inc, %bb2] +// ptr_val2 = IntToPtr(int_val) +// ... +// use(ptr_val2) +// ptr_val_inc = ... +// inc_val_inc = PtrToInt(ptr_val_inc) +// +// ==> +// bb1: +// br label %bb2 +// bb2: +// ptr_val = PHI([ptr_init, %bb1], [ptr_val_inc, %bb2] +// ... +// use(ptr_val) +// ptr_val_inc = ... +// +// Case-2: +// bb1: +// int_ptr = BitCast(ptr_ptr) +// int_init = Load(int_ptr) +// br label %bb2 +// bb2: +// int_val = PHI([int_init, %bb1], [int_val_inc, %bb2] +// ptr_val2 = IntToPtr(int_val) +// ... +// use(ptr_val2) +// ptr_val_inc = ... +// inc_val_inc = PtrToInt(ptr_val_inc) +// ==> +// bb1: +// ptr_init = Load(ptr_ptr) +// br label %bb2 +// bb2: +// ptr_val = PHI([ptr_init, %bb1], [ptr_val_inc, %bb2] +// ... 
+// use(ptr_val) +// ptr_val_inc = ... +// ... +// +Instruction *InstCombiner::FoldIntegerTypedPHI(PHINode &PN) { + if (!PN.getType()->isIntegerTy()) + return nullptr; + if (!PN.hasOneUse()) + return nullptr; + + auto *IntToPtr = dyn_cast<IntToPtrInst>(PN.user_back()); + if (!IntToPtr) + return nullptr; + + // Check if the pointer is actually used as pointer: + auto HasPointerUse = [](Instruction *IIP) { + for (User *U : IIP->users()) { + Value *Ptr = nullptr; + if (LoadInst *LoadI = dyn_cast<LoadInst>(U)) { + Ptr = LoadI->getPointerOperand(); + } else if (StoreInst *SI = dyn_cast<StoreInst>(U)) { + Ptr = SI->getPointerOperand(); + } else if (GetElementPtrInst *GI = dyn_cast<GetElementPtrInst>(U)) { + Ptr = GI->getPointerOperand(); + } + + if (Ptr && Ptr == IIP) + return true; + } + return false; + }; + + if (!HasPointerUse(IntToPtr)) + return nullptr; + + if (DL.getPointerSizeInBits(IntToPtr->getAddressSpace()) != + DL.getTypeSizeInBits(IntToPtr->getOperand(0)->getType())) + return nullptr; + + SmallVector<Value *, 4> AvailablePtrVals; + for (unsigned i = 0; i != PN.getNumIncomingValues(); ++i) { + Value *Arg = PN.getIncomingValue(i); + + // First look backward: + if (auto *PI = dyn_cast<PtrToIntInst>(Arg)) { + AvailablePtrVals.emplace_back(PI->getOperand(0)); + continue; + } + + // Next look forward: + Value *ArgIntToPtr = nullptr; + for (User *U : Arg->users()) { + if (isa<IntToPtrInst>(U) && U->getType() == IntToPtr->getType() && + (DT.dominates(cast<Instruction>(U), PN.getIncomingBlock(i)) || + cast<Instruction>(U)->getParent() == PN.getIncomingBlock(i))) { + ArgIntToPtr = U; + break; + } + } + + if (ArgIntToPtr) { + AvailablePtrVals.emplace_back(ArgIntToPtr); + continue; + } + + // If Arg is defined by a PHI, allow it. This will also create + // more opportunities iteratively. + if (isa<PHINode>(Arg)) { + AvailablePtrVals.emplace_back(Arg); + continue; + } + + // For a single use integer load: + auto *LoadI = dyn_cast<LoadInst>(Arg); + if (!LoadI) + return nullptr; + + if (!LoadI->hasOneUse()) + return nullptr; + + // Push the integer typed Load instruction into the available + // value set, and fix it up later when the pointer typed PHI + // is synthesized. + AvailablePtrVals.emplace_back(LoadI); + } + + // Now search for a matching PHI + auto *BB = PN.getParent(); + assert(AvailablePtrVals.size() == PN.getNumIncomingValues() && + "Not enough available ptr typed incoming values"); + PHINode *MatchingPtrPHI = nullptr; + for (auto II = BB->begin(), EI = BasicBlock::iterator(BB->getFirstNonPHI()); + II != EI; II++) { + PHINode *PtrPHI = dyn_cast<PHINode>(II); + if (!PtrPHI || PtrPHI == &PN || PtrPHI->getType() != IntToPtr->getType()) + continue; + MatchingPtrPHI = PtrPHI; + for (unsigned i = 0; i != PtrPHI->getNumIncomingValues(); ++i) { + if (AvailablePtrVals[i] != + PtrPHI->getIncomingValueForBlock(PN.getIncomingBlock(i))) { + MatchingPtrPHI = nullptr; + break; + } + } + + if (MatchingPtrPHI) + break; + } + + if (MatchingPtrPHI) { + assert(MatchingPtrPHI->getType() == IntToPtr->getType() && + "Phi's Type does not match with IntToPtr"); + // The PtrToCast + IntToPtr will be simplified later + return CastInst::CreateBitOrPointerCast(MatchingPtrPHI, + IntToPtr->getOperand(0)->getType()); } - return Loc; + // If it requires a conversion for every PHI operand, do not do it. 
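  // In other words (sketch of the check below): only proceed when at least one
  // incoming value already has the desired pointer type and is not itself an
  // IntToPtr that would just be folded back and forth.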
+ if (std::all_of(AvailablePtrVals.begin(), AvailablePtrVals.end(), + [&](Value *V) { + return (V->getType() != IntToPtr->getType()) || + isa<IntToPtrInst>(V); + })) + return nullptr; + + // If any of the operand that requires casting is a terminator + // instruction, do not do it. + if (std::any_of(AvailablePtrVals.begin(), AvailablePtrVals.end(), + [&](Value *V) { + return (V->getType() != IntToPtr->getType()) && + isa<TerminatorInst>(V); + })) + return nullptr; + + PHINode *NewPtrPHI = PHINode::Create( + IntToPtr->getType(), PN.getNumIncomingValues(), PN.getName() + ".ptr"); + + InsertNewInstBefore(NewPtrPHI, PN); + SmallDenseMap<Value *, Instruction *> Casts; + for (unsigned i = 0; i != PN.getNumIncomingValues(); ++i) { + auto *IncomingBB = PN.getIncomingBlock(i); + auto *IncomingVal = AvailablePtrVals[i]; + + if (IncomingVal->getType() == IntToPtr->getType()) { + NewPtrPHI->addIncoming(IncomingVal, IncomingBB); + continue; + } + +#ifndef NDEBUG + LoadInst *LoadI = dyn_cast<LoadInst>(IncomingVal); + assert((isa<PHINode>(IncomingVal) || + IncomingVal->getType()->isPointerTy() || + (LoadI && LoadI->hasOneUse())) && + "Can not replace LoadInst with multiple uses"); +#endif + // Need to insert a BitCast. + // For an integer Load instruction with a single use, the load + IntToPtr + // cast will be simplified into a pointer load: + // %v = load i64, i64* %a.ip, align 8 + // %v.cast = inttoptr i64 %v to float ** + // ==> + // %v.ptrp = bitcast i64 * %a.ip to float ** + // %v.cast = load float *, float ** %v.ptrp, align 8 + Instruction *&CI = Casts[IncomingVal]; + if (!CI) { + CI = CastInst::CreateBitOrPointerCast(IncomingVal, IntToPtr->getType(), + IncomingVal->getName() + ".ptr"); + if (auto *IncomingI = dyn_cast<Instruction>(IncomingVal)) { + BasicBlock::iterator InsertPos(IncomingI); + InsertPos++; + if (isa<PHINode>(IncomingI)) + InsertPos = IncomingI->getParent()->getFirstInsertionPt(); + InsertNewInstBefore(CI, *InsertPos); + } else { + auto *InsertBB = &IncomingBB->getParent()->getEntryBlock(); + InsertNewInstBefore(CI, *InsertBB->getFirstInsertionPt()); + } + } + NewPtrPHI->addIncoming(CI, IncomingBB); + } + + // The PtrToCast + IntToPtr will be simplified later + return CastInst::CreateBitOrPointerCast(NewPtrPHI, + IntToPtr->getOperand(0)->getType()); } /// If we have something like phi [add (a,b), add(a,c)] and if a/b/c and the @@ -117,7 +349,7 @@ Instruction *InstCombiner::FoldPHIArgBinOpIntoPHI(PHINode &PN) { if (CmpInst *CIOp = dyn_cast<CmpInst>(FirstInst)) { CmpInst *NewCI = CmpInst::Create(CIOp->getOpcode(), CIOp->getPredicate(), LHSVal, RHSVal); - NewCI->setDebugLoc(PHIArgMergedDebugLoc(PN)); + PHIArgMergedDebugLoc(NewCI, PN); return NewCI; } @@ -130,7 +362,7 @@ Instruction *InstCombiner::FoldPHIArgBinOpIntoPHI(PHINode &PN) { for (unsigned i = 1, e = PN.getNumIncomingValues(); i != e; ++i) NewBinOp->andIRFlags(PN.getIncomingValue(i)); - NewBinOp->setDebugLoc(PHIArgMergedDebugLoc(PN)); + PHIArgMergedDebugLoc(NewBinOp, PN); return NewBinOp; } @@ -239,7 +471,7 @@ Instruction *InstCombiner::FoldPHIArgGEPIntoPHI(PHINode &PN) { GetElementPtrInst::Create(FirstInst->getSourceElementType(), Base, makeArrayRef(FixedOperands).slice(1)); if (AllInBounds) NewGEP->setIsInBounds(); - NewGEP->setDebugLoc(PHIArgMergedDebugLoc(PN)); + PHIArgMergedDebugLoc(NewGEP, PN); return NewGEP; } @@ -399,7 +631,7 @@ Instruction *InstCombiner::FoldPHIArgLoadIntoPHI(PHINode &PN) { for (Value *IncValue : PN.incoming_values()) cast<LoadInst>(IncValue)->setVolatile(false); - 
NewLI->setDebugLoc(PHIArgMergedDebugLoc(PN)); + PHIArgMergedDebugLoc(NewLI, PN); return NewLI; } @@ -565,7 +797,7 @@ Instruction *InstCombiner::FoldPHIArgOpIntoPHI(PHINode &PN) { if (CastInst *FirstCI = dyn_cast<CastInst>(FirstInst)) { CastInst *NewCI = CastInst::Create(FirstCI->getOpcode(), PhiVal, PN.getType()); - NewCI->setDebugLoc(PHIArgMergedDebugLoc(PN)); + PHIArgMergedDebugLoc(NewCI, PN); return NewCI; } @@ -576,14 +808,14 @@ Instruction *InstCombiner::FoldPHIArgOpIntoPHI(PHINode &PN) { for (unsigned i = 1, e = PN.getNumIncomingValues(); i != e; ++i) BinOp->andIRFlags(PN.getIncomingValue(i)); - BinOp->setDebugLoc(PHIArgMergedDebugLoc(PN)); + PHIArgMergedDebugLoc(BinOp, PN); return BinOp; } CmpInst *CIOp = cast<CmpInst>(FirstInst); CmpInst *NewCI = CmpInst::Create(CIOp->getOpcode(), CIOp->getPredicate(), PhiVal, ConstantOp); - NewCI->setDebugLoc(PHIArgMergedDebugLoc(PN)); + PHIArgMergedDebugLoc(NewCI, PN); return NewCI; } @@ -902,6 +1134,9 @@ Instruction *InstCombiner::visitPHINode(PHINode &PN) { // this PHI only has a single use (a PHI), and if that PHI only has one use (a // PHI)... break the cycle. if (PN.hasOneUse()) { + if (Instruction *Result = FoldIntegerTypedPHI(PN)) + return Result; + Instruction *PHIUser = cast<Instruction>(PN.user_back()); if (PHINode *PU = dyn_cast<PHINode>(PHIUser)) { SmallPtrSet<PHINode*, 16> PotentiallyDeadPHIs; diff --git a/lib/Transforms/InstCombine/InstCombineSelect.cpp b/lib/Transforms/InstCombine/InstCombineSelect.cpp index 4eebe8255998..6f26f7f5cd19 100644 --- a/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -12,12 +12,36 @@ //===----------------------------------------------------------------------===// #include "InstCombineInternal.h" -#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/CmpInstAnalysis.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/ValueTracking.h" -#include "llvm/IR/MDBuilder.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/KnownBits.h" +#include "llvm/Transforms/InstCombine/InstCombineWorklist.h" +#include <cassert> +#include <utility> + using namespace llvm; using namespace PatternMatch; @@ -69,6 +93,111 @@ static Value *generateMinMaxSelectPattern(InstCombiner::BuilderTy &Builder, return Builder.CreateSelect(Builder.CreateICmp(Pred, A, B), A, B); } +/// If one of the constants is zero (we know they can't both be) and we have an +/// icmp instruction with zero, and we have an 'and' with the non-constant value +/// and a power of two we can turn the select into a shift on the result of the +/// 'and'. +/// This folds: +/// select (icmp eq (and X, C1)), C2, C3 +/// iff C1 is a power 2 and the difference between C2 and C3 is a power of 2. 
+/// To something like: +/// (shr (and (X, C1)), (log2(C1) - log2(C2-C3))) + C3 +/// Or: +/// (shl (and (X, C1)), (log2(C2-C3) - log2(C1))) + C3 +/// With some variations depending if C3 is larger than C2, or the shift +/// isn't needed, or the bit widths don't match. +static Value *foldSelectICmpAnd(Type *SelType, const ICmpInst *IC, + APInt TrueVal, APInt FalseVal, + InstCombiner::BuilderTy &Builder) { + assert(SelType->isIntOrIntVectorTy() && "Not an integer select?"); + + // If this is a vector select, we need a vector compare. + if (SelType->isVectorTy() != IC->getType()->isVectorTy()) + return nullptr; + + Value *V; + APInt AndMask; + bool CreateAnd = false; + ICmpInst::Predicate Pred = IC->getPredicate(); + if (ICmpInst::isEquality(Pred)) { + if (!match(IC->getOperand(1), m_Zero())) + return nullptr; + + V = IC->getOperand(0); + + const APInt *AndRHS; + if (!match(V, m_And(m_Value(), m_Power2(AndRHS)))) + return nullptr; + + AndMask = *AndRHS; + } else if (decomposeBitTestICmp(IC->getOperand(0), IC->getOperand(1), + Pred, V, AndMask)) { + assert(ICmpInst::isEquality(Pred) && "Not equality test?"); + + if (!AndMask.isPowerOf2()) + return nullptr; + + CreateAnd = true; + } else { + return nullptr; + } + + // If both select arms are non-zero see if we have a select of the form + // 'x ? 2^n + C : C'. Then we can offset both arms by C, use the logic + // for 'x ? 2^n : 0' and fix the thing up at the end. + APInt Offset(TrueVal.getBitWidth(), 0); + if (!TrueVal.isNullValue() && !FalseVal.isNullValue()) { + if ((TrueVal - FalseVal).isPowerOf2()) + Offset = FalseVal; + else if ((FalseVal - TrueVal).isPowerOf2()) + Offset = TrueVal; + else + return nullptr; + + // Adjust TrueVal and FalseVal to the offset. + TrueVal -= Offset; + FalseVal -= Offset; + } + + // Make sure one of the select arms is a power of 2. + if (!TrueVal.isPowerOf2() && !FalseVal.isPowerOf2()) + return nullptr; + + // Determine which shift is needed to transform result of the 'and' into the + // desired result. + const APInt &ValC = !TrueVal.isNullValue() ? TrueVal : FalseVal; + unsigned ValZeros = ValC.logBase2(); + unsigned AndZeros = AndMask.logBase2(); + + if (CreateAnd) { + // Insert the AND instruction on the input to the truncate. + V = Builder.CreateAnd(V, ConstantInt::get(V->getType(), AndMask)); + } + + // If types don't match we can still convert the select by introducing a zext + // or a trunc of the 'and'. + if (ValZeros > AndZeros) { + V = Builder.CreateZExtOrTrunc(V, SelType); + V = Builder.CreateShl(V, ValZeros - AndZeros); + } else if (ValZeros < AndZeros) { + V = Builder.CreateLShr(V, AndZeros - ValZeros); + V = Builder.CreateZExtOrTrunc(V, SelType); + } else + V = Builder.CreateZExtOrTrunc(V, SelType); + + // Okay, now we know that everything is set up, we just don't know whether we + // have a icmp_ne or icmp_eq and whether the true or false val is the zero. + bool ShouldNotVal = !TrueVal.isNullValue(); + ShouldNotVal ^= Pred == ICmpInst::ICMP_NE; + if (ShouldNotVal) + V = Builder.CreateXor(V, ValC); + + // Apply an offset if needed. 
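  // Worked instance (illustrative): for select (icmp ne (and %x, 4), 0), i32 8, i32 0
  // the masked bit is shifted left by one (log2(8) - log2(4)) and neither the
  // xor nor the offset below is needed; with the arms swapped (icmp eq), the
  // same shift is followed by an xor with 8.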
+ if (!Offset.isNullValue()) + V = Builder.CreateAdd(V, ConstantInt::get(V->getType(), Offset)); + return V; +} + /// We want to turn code that looks like this: /// %C = or %A, %B /// %D = select %cond, %C, %A @@ -79,8 +208,7 @@ static Value *generateMinMaxSelectPattern(InstCombiner::BuilderTy &Builder, /// Assuming that the specified instruction is an operand to the select, return /// a bitmask indicating which operands of this instruction are foldable if they /// equal the other incoming value of the select. -/// -static unsigned getSelectFoldableOperands(Instruction *I) { +static unsigned getSelectFoldableOperands(BinaryOperator *I) { switch (I->getOpcode()) { case Instruction::Add: case Instruction::Mul: @@ -100,7 +228,7 @@ static unsigned getSelectFoldableOperands(Instruction *I) { /// For the same transformation as the previous function, return the identity /// constant that goes into the select. -static Constant *getSelectFoldableConstant(Instruction *I) { +static APInt getSelectFoldableConstant(BinaryOperator *I) { switch (I->getOpcode()) { default: llvm_unreachable("This cannot happen!"); case Instruction::Add: @@ -110,11 +238,11 @@ static Constant *getSelectFoldableConstant(Instruction *I) { case Instruction::Shl: case Instruction::LShr: case Instruction::AShr: - return Constant::getNullValue(I->getType()); + return APInt::getNullValue(I->getType()->getScalarSizeInBits()); case Instruction::And: - return Constant::getAllOnesValue(I->getType()); + return APInt::getAllOnesValue(I->getType()->getScalarSizeInBits()); case Instruction::Mul: - return ConstantInt::get(I->getType(), 1); + return APInt(I->getType()->getScalarSizeInBits(), 1); } } @@ -157,7 +285,6 @@ Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI, if (TI->getOpcode() != Instruction::BitCast && (!TI->hasOneUse() || !FI->hasOneUse())) return nullptr; - } else if (!TI->hasOneUse() || !FI->hasOneUse()) { // TODO: The one-use restrictions for a scalar select could be eased if // the fold of a select in visitLoadInst() was enhanced to match a pattern @@ -218,17 +345,11 @@ Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI, return BinaryOperator::Create(BO->getOpcode(), Op0, Op1); } -static bool isSelect01(Constant *C1, Constant *C2) { - ConstantInt *C1I = dyn_cast<ConstantInt>(C1); - if (!C1I) - return false; - ConstantInt *C2I = dyn_cast<ConstantInt>(C2); - if (!C2I) - return false; - if (!C1I->isZero() && !C2I->isZero()) // One side must be zero. +static bool isSelect01(const APInt &C1I, const APInt &C2I) { + if (!C1I.isNullValue() && !C2I.isNullValue()) // One side must be zero. return false; - return C1I->isOne() || C1I->isMinusOne() || - C2I->isOne() || C2I->isMinusOne(); + return C1I.isOneValue() || C1I.isAllOnesValue() || + C2I.isOneValue() || C2I.isAllOnesValue(); } /// Try to fold the select into one of the operands to allow further @@ -237,9 +358,8 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, Value *FalseVal) { // See the comment above GetSelectFoldableOperands for a description of the // transformation we are doing here. 
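  // e.g. (sketch): select %c, (add %x, %y), %x
  //   ==> add %x, (select %c, %y, 0)
  // i.e. the untaken arm is replaced by the identity constant of the operation
  // (0 for add and the shifts, all-ones for and, 1 for mul, as chosen above).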
- if (Instruction *TVI = dyn_cast<Instruction>(TrueVal)) { - if (TVI->hasOneUse() && TVI->getNumOperands() == 2 && - !isa<Constant>(FalseVal)) { + if (auto *TVI = dyn_cast<BinaryOperator>(TrueVal)) { + if (TVI->hasOneUse() && !isa<Constant>(FalseVal)) { if (unsigned SFO = getSelectFoldableOperands(TVI)) { unsigned OpToFold = 0; if ((SFO & 1) && FalseVal == TVI->getOperand(0)) { @@ -249,17 +369,19 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, } if (OpToFold) { - Constant *C = getSelectFoldableConstant(TVI); + APInt CI = getSelectFoldableConstant(TVI); Value *OOp = TVI->getOperand(2-OpToFold); // Avoid creating select between 2 constants unless it's selecting // between 0, 1 and -1. - if (!isa<Constant>(OOp) || isSelect01(C, cast<Constant>(OOp))) { + const APInt *OOpC; + bool OOpIsAPInt = match(OOp, m_APInt(OOpC)); + if (!isa<Constant>(OOp) || (OOpIsAPInt && isSelect01(CI, *OOpC))) { + Value *C = ConstantInt::get(OOp->getType(), CI); Value *NewSel = Builder.CreateSelect(SI.getCondition(), OOp, C); NewSel->takeName(TVI); - BinaryOperator *TVI_BO = cast<BinaryOperator>(TVI); - BinaryOperator *BO = BinaryOperator::Create(TVI_BO->getOpcode(), + BinaryOperator *BO = BinaryOperator::Create(TVI->getOpcode(), FalseVal, NewSel); - BO->copyIRFlags(TVI_BO); + BO->copyIRFlags(TVI); return BO; } } @@ -267,9 +389,8 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, } } - if (Instruction *FVI = dyn_cast<Instruction>(FalseVal)) { - if (FVI->hasOneUse() && FVI->getNumOperands() == 2 && - !isa<Constant>(TrueVal)) { + if (auto *FVI = dyn_cast<BinaryOperator>(FalseVal)) { + if (FVI->hasOneUse() && !isa<Constant>(TrueVal)) { if (unsigned SFO = getSelectFoldableOperands(FVI)) { unsigned OpToFold = 0; if ((SFO & 1) && TrueVal == FVI->getOperand(0)) { @@ -279,17 +400,19 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, } if (OpToFold) { - Constant *C = getSelectFoldableConstant(FVI); + APInt CI = getSelectFoldableConstant(FVI); Value *OOp = FVI->getOperand(2-OpToFold); // Avoid creating select between 2 constants unless it's selecting // between 0, 1 and -1. - if (!isa<Constant>(OOp) || isSelect01(C, cast<Constant>(OOp))) { + const APInt *OOpC; + bool OOpIsAPInt = match(OOp, m_APInt(OOpC)); + if (!isa<Constant>(OOp) || (OOpIsAPInt && isSelect01(CI, *OOpC))) { + Value *C = ConstantInt::get(OOp->getType(), CI); Value *NewSel = Builder.CreateSelect(SI.getCondition(), C, OOp); NewSel->takeName(FVI); - BinaryOperator *FVI_BO = cast<BinaryOperator>(FVI); - BinaryOperator *BO = BinaryOperator::Create(FVI_BO->getOpcode(), + BinaryOperator *BO = BinaryOperator::Create(FVI->getOpcode(), TrueVal, NewSel); - BO->copyIRFlags(FVI_BO); + BO->copyIRFlags(FVI); return BO; } } @@ -313,11 +436,13 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, /// 1. The icmp predicate is inverted /// 2. The select operands are reversed /// 3. The magnitude of C2 and C1 are flipped -static Value *foldSelectICmpAndOr(const SelectInst &SI, Value *TrueVal, +static Value *foldSelectICmpAndOr(const ICmpInst *IC, Value *TrueVal, Value *FalseVal, InstCombiner::BuilderTy &Builder) { - const ICmpInst *IC = dyn_cast<ICmpInst>(SI.getCondition()); - if (!IC || !SI.getType()->isIntegerTy()) + // Only handle integer compares. Also, if this is a vector select, we need a + // vector compare. 
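  // The shape being matched here, sketched on scalars:
  //   select (icmp eq (and %x, 1), 0), %y, (or %y, 4)
  //   ==> or (shl (and %x, 1), 2), %y
  // (the tested bit is shifted into the position of the or'd constant).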
+ if (!TrueVal->getType()->isIntOrIntVectorTy() || + TrueVal->getType()->isVectorTy() != IC->getType()->isVectorTy()) return nullptr; Value *CmpLHS = IC->getOperand(0); @@ -371,8 +496,8 @@ static Value *foldSelectICmpAndOr(const SelectInst &SI, Value *TrueVal, bool NeedXor = (!IsEqualZero && OrOnFalseVal) || (IsEqualZero && OrOnTrueVal); bool NeedShift = C1Log != C2Log; - bool NeedZExtTrunc = Y->getType()->getIntegerBitWidth() != - V->getType()->getIntegerBitWidth(); + bool NeedZExtTrunc = Y->getType()->getScalarSizeInBits() != + V->getType()->getScalarSizeInBits(); // Make sure we don't create more instructions than we save. Value *Or = OrOnFalseVal ? FalseVal : TrueVal; @@ -447,8 +572,7 @@ static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal, IntrinsicInst *II = cast<IntrinsicInst>(Count); // Explicitly clear the 'undef_on_zero' flag. IntrinsicInst *NewI = cast<IntrinsicInst>(II->clone()); - Type *Ty = NewI->getArgOperand(1)->getType(); - NewI->setArgOperand(1, Constant::getNullValue(Ty)); + NewI->setArgOperand(1, ConstantInt::getFalse(NewI->getContext())); Builder.Insert(NewI); return Builder.CreateZExtOrTrunc(NewI, ValueOnZero->getType()); } @@ -597,6 +721,9 @@ canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp, /// Visit a SelectInst that has an ICmpInst as its first operand. Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, ICmpInst *ICI) { + Value *TrueVal = SI.getTrueValue(); + Value *FalseVal = SI.getFalseValue(); + if (Instruction *NewSel = canonicalizeMinMaxWithConstant(SI, *ICI, Builder)) return NewSel; @@ -605,40 +732,52 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, ICmpInst::Predicate Pred = ICI->getPredicate(); Value *CmpLHS = ICI->getOperand(0); Value *CmpRHS = ICI->getOperand(1); - Value *TrueVal = SI.getTrueValue(); - Value *FalseVal = SI.getFalseValue(); // Transform (X >s -1) ? C1 : C2 --> ((X >>s 31) & (C2 - C1)) + C1 // and (X <s 0) ? C2 : C1 --> ((X >>s 31) & (C2 - C1)) + C1 // FIXME: Type and constness constraints could be lifted, but we have to // watch code size carefully. We should consider xor instead of // sub/add when we decide to do that. - if (IntegerType *Ty = dyn_cast<IntegerType>(CmpLHS->getType())) { - if (TrueVal->getType() == Ty) { - if (ConstantInt *Cmp = dyn_cast<ConstantInt>(CmpRHS)) { - ConstantInt *C1 = nullptr, *C2 = nullptr; - if (Pred == ICmpInst::ICMP_SGT && Cmp->isMinusOne()) { - C1 = dyn_cast<ConstantInt>(TrueVal); - C2 = dyn_cast<ConstantInt>(FalseVal); - } else if (Pred == ICmpInst::ICMP_SLT && Cmp->isZero()) { - C1 = dyn_cast<ConstantInt>(FalseVal); - C2 = dyn_cast<ConstantInt>(TrueVal); - } - if (C1 && C2) { + // TODO: Merge this with foldSelectICmpAnd somehow. + if (CmpLHS->getType()->isIntOrIntVectorTy() && + CmpLHS->getType() == TrueVal->getType()) { + const APInt *C1, *C2; + if (match(TrueVal, m_APInt(C1)) && match(FalseVal, m_APInt(C2))) { + ICmpInst::Predicate Pred = ICI->getPredicate(); + Value *X; + APInt Mask; + if (decomposeBitTestICmp(CmpLHS, CmpRHS, Pred, X, Mask, false)) { + if (Mask.isSignMask()) { + assert(X == CmpLHS && "Expected to use the compare input directly"); + assert(ICmpInst::isEquality(Pred) && "Expected equality predicate"); + + if (Pred == ICmpInst::ICMP_NE) + std::swap(C1, C2); + // This shift results in either -1 or 0. - Value *AShr = Builder.CreateAShr(CmpLHS, Ty->getBitWidth() - 1); + Value *AShr = Builder.CreateAShr(X, Mask.getBitWidth() - 1); // Check if we can express the operation with a single or. 
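  // e.g. (illustrative, i32): select (icmp sgt %x, -1), 7, -1
  //   ==> or (ashr %x, 31), 7
  // while select (icmp sgt %x, -1), 10, 26 needs the general form:
  //   add (and (ashr %x, 31), 16), 10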
- if (C2->isMinusOne()) - return replaceInstUsesWith(SI, Builder.CreateOr(AShr, C1)); + if (C2->isAllOnesValue()) + return replaceInstUsesWith(SI, Builder.CreateOr(AShr, *C1)); - Value *And = Builder.CreateAnd(AShr, C2->getValue() - C1->getValue()); - return replaceInstUsesWith(SI, Builder.CreateAdd(And, C1)); + Value *And = Builder.CreateAnd(AShr, *C2 - *C1); + return replaceInstUsesWith(SI, Builder.CreateAdd(And, + ConstantInt::get(And->getType(), *C1))); } } } } + { + const APInt *TrueValC, *FalseValC; + if (match(TrueVal, m_APInt(TrueValC)) && + match(FalseVal, m_APInt(FalseValC))) + if (Value *V = foldSelectICmpAnd(SI.getType(), ICI, *TrueValC, + *FalseValC, Builder)) + return replaceInstUsesWith(SI, V); + } + // NOTE: if we wanted to, this is where to detect integer MIN/MAX if (CmpRHS != CmpLHS && isa<Constant>(CmpRHS)) { @@ -703,7 +842,7 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, } } - if (Value *V = foldSelectICmpAndOr(SI, TrueVal, FalseVal, Builder)) + if (Value *V = foldSelectICmpAndOr(ICI, TrueVal, FalseVal, Builder)) return replaceInstUsesWith(SI, V); if (Value *V = foldSelectCttzCtlz(ICI, TrueVal, FalseVal, Builder)) @@ -722,7 +861,6 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, /// Z = select X, Y, 0 /// /// because Y is not live in BB1/BB2. -/// static bool canSelectOperandBeMappingIntoPredBlock(const Value *V, const SelectInst &SI) { // If the value is a non-instruction value like a constant or argument, it @@ -864,78 +1002,6 @@ Instruction *InstCombiner::foldSPFofSPF(Instruction *Inner, return nullptr; } -/// If one of the constants is zero (we know they can't both be) and we have an -/// icmp instruction with zero, and we have an 'and' with the non-constant value -/// and a power of two we can turn the select into a shift on the result of the -/// 'and'. -static Value *foldSelectICmpAnd(const SelectInst &SI, APInt TrueVal, - APInt FalseVal, - InstCombiner::BuilderTy &Builder) { - const ICmpInst *IC = dyn_cast<ICmpInst>(SI.getCondition()); - if (!IC || !IC->isEquality() || !SI.getType()->isIntegerTy()) - return nullptr; - - if (!match(IC->getOperand(1), m_Zero())) - return nullptr; - - ConstantInt *AndRHS; - Value *LHS = IC->getOperand(0); - if (!match(LHS, m_And(m_Value(), m_ConstantInt(AndRHS)))) - return nullptr; - - // If both select arms are non-zero see if we have a select of the form - // 'x ? 2^n + C : C'. Then we can offset both arms by C, use the logic - // for 'x ? 2^n : 0' and fix the thing up at the end. - APInt Offset(TrueVal.getBitWidth(), 0); - if (!TrueVal.isNullValue() && !FalseVal.isNullValue()) { - if ((TrueVal - FalseVal).isPowerOf2()) - Offset = FalseVal; - else if ((FalseVal - TrueVal).isPowerOf2()) - Offset = TrueVal; - else - return nullptr; - - // Adjust TrueVal and FalseVal to the offset. - TrueVal -= Offset; - FalseVal -= Offset; - } - - // Make sure the mask in the 'and' and one of the select arms is a power of 2. - if (!AndRHS->getValue().isPowerOf2() || - (!TrueVal.isPowerOf2() && !FalseVal.isPowerOf2())) - return nullptr; - - // Determine which shift is needed to transform result of the 'and' into the - // desired result. - const APInt &ValC = !TrueVal.isNullValue() ? TrueVal : FalseVal; - unsigned ValZeros = ValC.logBase2(); - unsigned AndZeros = AndRHS->getValue().logBase2(); - - // If types don't match we can still convert the select by introducing a zext - // or a trunc of the 'and'. 
The trunc case requires that all of the truncated - // bits are zero, we can figure that out by looking at the 'and' mask. - if (AndZeros >= ValC.getBitWidth()) - return nullptr; - - Value *V = Builder.CreateZExtOrTrunc(LHS, SI.getType()); - if (ValZeros > AndZeros) - V = Builder.CreateShl(V, ValZeros - AndZeros); - else if (ValZeros < AndZeros) - V = Builder.CreateLShr(V, AndZeros - ValZeros); - - // Okay, now we know that everything is set up, we just don't know whether we - // have a icmp_ne or icmp_eq and whether the true or false val is the zero. - bool ShouldNotVal = !TrueVal.isNullValue(); - ShouldNotVal ^= IC->getPredicate() == ICmpInst::ICMP_NE; - if (ShouldNotVal) - V = Builder.CreateXor(V, ValC); - - // Apply an offset if needed. - if (!Offset.isNullValue()) - V = Builder.CreateAdd(V, ConstantInt::get(V->getType(), Offset)); - return V; -} - /// Turn select C, (X + Y), (X - Y) --> (X + (select C, Y, (-Y))). /// This is even legal for FP. static Instruction *foldAddSubSelect(SelectInst &SI, @@ -1151,12 +1217,100 @@ static Instruction *foldSelectCmpBitcasts(SelectInst &Sel, return CastInst::CreateBitOrPointerCast(NewSel, Sel.getType()); } +/// Try to eliminate select instructions that test the returned flag of cmpxchg +/// instructions. +/// +/// If a select instruction tests the returned flag of a cmpxchg instruction and +/// selects between the returned value of the cmpxchg instruction its compare +/// operand, the result of the select will always be equal to its false value. +/// For example: +/// +/// %0 = cmpxchg i64* %ptr, i64 %compare, i64 %new_value seq_cst seq_cst +/// %1 = extractvalue { i64, i1 } %0, 1 +/// %2 = extractvalue { i64, i1 } %0, 0 +/// %3 = select i1 %1, i64 %compare, i64 %2 +/// ret i64 %3 +/// +/// The returned value of the cmpxchg instruction (%2) is the original value +/// located at %ptr prior to any update. If the cmpxchg operation succeeds, %2 +/// must have been equal to %compare. Thus, the result of the select is always +/// equal to %2, and the code can be simplified to: +/// +/// %0 = cmpxchg i64* %ptr, i64 %compare, i64 %new_value seq_cst seq_cst +/// %1 = extractvalue { i64, i1 } %0, 0 +/// ret i64 %1 +/// +static Instruction *foldSelectCmpXchg(SelectInst &SI) { + // A helper that determines if V is an extractvalue instruction whose + // aggregate operand is a cmpxchg instruction and whose single index is equal + // to I. If such conditions are true, the helper returns the cmpxchg + // instruction; otherwise, a nullptr is returned. + auto isExtractFromCmpXchg = [](Value *V, unsigned I) -> AtomicCmpXchgInst * { + auto *Extract = dyn_cast<ExtractValueInst>(V); + if (!Extract) + return nullptr; + if (Extract->getIndices()[0] != I) + return nullptr; + return dyn_cast<AtomicCmpXchgInst>(Extract->getAggregateOperand()); + }; + + // If the select has a single user, and this user is a select instruction that + // we can simplify, skip the cmpxchg simplification for now. + if (SI.hasOneUse()) + if (auto *Select = dyn_cast<SelectInst>(SI.user_back())) + if (Select->getCondition() == SI.getCondition()) + if (Select->getFalseValue() == SI.getTrueValue() || + Select->getTrueValue() == SI.getFalseValue()) + return nullptr; + + // Ensure the select condition is the returned flag of a cmpxchg instruction. 
+ auto *CmpXchg = isExtractFromCmpXchg(SI.getCondition(), 1); + if (!CmpXchg) + return nullptr; + + // Check the true value case: The true value of the select is the returned + // value of the same cmpxchg used by the condition, and the false value is the + // cmpxchg instruction's compare operand. + if (auto *X = isExtractFromCmpXchg(SI.getTrueValue(), 0)) + if (X == CmpXchg && X->getCompareOperand() == SI.getFalseValue()) { + SI.setTrueValue(SI.getFalseValue()); + return &SI; + } + + // Check the false value case: The false value of the select is the returned + // value of the same cmpxchg used by the condition, and the true value is the + // cmpxchg instruction's compare operand. + if (auto *X = isExtractFromCmpXchg(SI.getFalseValue(), 0)) + if (X == CmpXchg && X->getCompareOperand() == SI.getTrueValue()) { + SI.setTrueValue(SI.getFalseValue()); + return &SI; + } + + return nullptr; +} + Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Value *CondVal = SI.getCondition(); Value *TrueVal = SI.getTrueValue(); Value *FalseVal = SI.getFalseValue(); Type *SelType = SI.getType(); + // FIXME: Remove this workaround when freeze related patches are done. + // For select with undef operand which feeds into an equality comparison, + // don't simplify it so loop unswitch can know the equality comparison + // may have an undef operand. This is a workaround for PR31652 caused by + // descrepancy about branch on undef between LoopUnswitch and GVN. + if (isa<UndefValue>(TrueVal) || isa<UndefValue>(FalseVal)) { + if (llvm::any_of(SI.users(), [&](User *U) { + ICmpInst *CI = dyn_cast<ICmpInst>(U); + if (CI && CI->isEquality()) + return true; + return false; + })) { + return nullptr; + } + } + if (Value *V = SimplifySelectInst(CondVal, TrueVal, FalseVal, SQ.getWithInstruction(&SI))) return replaceInstUsesWith(SI, V); @@ -1246,12 +1400,6 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { } } - if (ConstantInt *TrueValC = dyn_cast<ConstantInt>(TrueVal)) - if (ConstantInt *FalseValC = dyn_cast<ConstantInt>(FalseVal)) - if (Value *V = foldSelectICmpAnd(SI, TrueValC->getValue(), - FalseValC->getValue(), Builder)) - return replaceInstUsesWith(SI, V); - // See if we are selecting two values based on a comparison of the two values. if (FCmpInst *FCI = dyn_cast<FCmpInst>(CondVal)) { if (FCI->getOperand(0) == TrueVal && FCI->getOperand(1) == FalseVal) { @@ -1373,9 +1521,17 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { auto SPF = SPR.Flavor; if (SelectPatternResult::isMinOrMax(SPF)) { - // Canonicalize so that type casts are outside select patterns. - if (LHS->getType()->getPrimitiveSizeInBits() != - SelType->getPrimitiveSizeInBits()) { + // Canonicalize so that + // - type casts are outside select patterns. 
+ // - float clamp is transformed to min/max pattern + + bool IsCastNeeded = LHS->getType() != SelType; + Value *CmpLHS = cast<CmpInst>(CondVal)->getOperand(0); + Value *CmpRHS = cast<CmpInst>(CondVal)->getOperand(1); + if (IsCastNeeded || + (LHS->getType()->isFPOrFPVectorTy() && + ((CmpLHS != LHS && CmpLHS != RHS) || + (CmpRHS != LHS && CmpRHS != RHS)))) { CmpInst::Predicate Pred = getCmpPredicateForMinMax(SPF, SPR.Ordered); Value *Cmp; @@ -1388,10 +1544,12 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Cmp = Builder.CreateFCmp(Pred, LHS, RHS); } - Value *NewSI = Builder.CreateCast( - CastOp, Builder.CreateSelect(Cmp, LHS, RHS, SI.getName(), &SI), - SelType); - return replaceInstUsesWith(SI, NewSI); + Value *NewSI = Builder.CreateSelect(Cmp, LHS, RHS, SI.getName(), &SI); + if (!IsCastNeeded) + return replaceInstUsesWith(SI, NewSI); + + Value *NewCast = Builder.CreateCast(CastOp, NewSI, SelType); + return replaceInstUsesWith(SI, NewCast); } } @@ -1485,6 +1643,46 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { } } + // Try to simplify a binop sandwiched between 2 selects with the same + // condition. + // select(C, binop(select(C, X, Y), W), Z) -> select(C, binop(X, W), Z) + BinaryOperator *TrueBO; + if (match(TrueVal, m_OneUse(m_BinOp(TrueBO)))) { + if (auto *TrueBOSI = dyn_cast<SelectInst>(TrueBO->getOperand(0))) { + if (TrueBOSI->getCondition() == CondVal) { + TrueBO->setOperand(0, TrueBOSI->getTrueValue()); + Worklist.Add(TrueBO); + return &SI; + } + } + if (auto *TrueBOSI = dyn_cast<SelectInst>(TrueBO->getOperand(1))) { + if (TrueBOSI->getCondition() == CondVal) { + TrueBO->setOperand(1, TrueBOSI->getTrueValue()); + Worklist.Add(TrueBO); + return &SI; + } + } + } + + // select(C, Z, binop(select(C, X, Y), W)) -> select(C, Z, binop(Y, W)) + BinaryOperator *FalseBO; + if (match(FalseVal, m_OneUse(m_BinOp(FalseBO)))) { + if (auto *FalseBOSI = dyn_cast<SelectInst>(FalseBO->getOperand(0))) { + if (FalseBOSI->getCondition() == CondVal) { + FalseBO->setOperand(0, FalseBOSI->getFalseValue()); + Worklist.Add(FalseBO); + return &SI; + } + } + if (auto *FalseBOSI = dyn_cast<SelectInst>(FalseBO->getOperand(1))) { + if (FalseBOSI->getCondition() == CondVal) { + FalseBO->setOperand(1, FalseBOSI->getFalseValue()); + Worklist.Add(FalseBO); + return &SI; + } + } + } + if (BinaryOperator::isNot(CondVal)) { SI.setOperand(0, BinaryOperator::getNotArgument(CondVal)); SI.setOperand(1, FalseVal); @@ -1501,10 +1699,6 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { return replaceInstUsesWith(SI, V); return &SI; } - - if (isa<ConstantAggregateZero>(CondVal)) { - return replaceInstUsesWith(SI, FalseVal); - } } // See if we can determine the result of this select based on a dominating @@ -1515,9 +1709,9 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { if (PBI && PBI->isConditional() && PBI->getSuccessor(0) != PBI->getSuccessor(1) && (PBI->getSuccessor(0) == Parent || PBI->getSuccessor(1) == Parent)) { - bool CondIsFalse = PBI->getSuccessor(1) == Parent; + bool CondIsTrue = PBI->getSuccessor(0) == Parent; Optional<bool> Implication = isImpliedCondition( - PBI->getCondition(), SI.getCondition(), DL, CondIsFalse); + PBI->getCondition(), SI.getCondition(), DL, CondIsTrue); if (Implication) { Value *V = *Implication ? 
TrueVal : FalseVal; return replaceInstUsesWith(SI, V); @@ -1542,5 +1736,9 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { if (Instruction *BitCastSel = foldSelectCmpBitcasts(SI, Builder)) return BitCastSel; + // Simplify selects that test the returned flag of cmpxchg instructions. + if (Instruction *Select = foldSelectCmpXchg(SI)) + return Select; + return nullptr; } diff --git a/lib/Transforms/InstCombine/InstCombineShifts.cpp b/lib/Transforms/InstCombine/InstCombineShifts.cpp index 7ed141c7fd79..44bbb84686ab 100644 --- a/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -310,6 +310,40 @@ static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift, } } +// If this is a bitwise operator or add with a constant RHS we might be able +// to pull it through a shift. +static bool canShiftBinOpWithConstantRHS(BinaryOperator &Shift, + BinaryOperator *BO, + const APInt &C) { + bool IsValid = true; // Valid only for And, Or Xor, + bool HighBitSet = false; // Transform ifhigh bit of constant set? + + switch (BO->getOpcode()) { + default: IsValid = false; break; // Do not perform transform! + case Instruction::Add: + IsValid = Shift.getOpcode() == Instruction::Shl; + break; + case Instruction::Or: + case Instruction::Xor: + HighBitSet = false; + break; + case Instruction::And: + HighBitSet = true; + break; + } + + // If this is a signed shift right, and the high bit is modified + // by the logical operation, do not perform the transformation. + // The HighBitSet boolean indicates the value of the high bit of + // the constant which would cause it to be modified for this + // operation. + // + if (IsValid && Shift.getOpcode() == Instruction::AShr) + IsValid = C.isNegative() == HighBitSet; + + return IsValid; +} + Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, BinaryOperator &I) { bool isLeftShift = I.getOpcode() == Instruction::Shl; @@ -470,35 +504,11 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, // If the operand is a bitwise operator with a constant RHS, and the // shift is the only use, we can pull it out of the shift. - if (ConstantInt *Op0C = dyn_cast<ConstantInt>(Op0BO->getOperand(1))) { - bool isValid = true; // Valid only for And, Or, Xor - bool highBitSet = false; // Transform if high bit of constant set? - - switch (Op0BO->getOpcode()) { - default: isValid = false; break; // Do not perform transform! - case Instruction::Add: - isValid = isLeftShift; - break; - case Instruction::Or: - case Instruction::Xor: - highBitSet = false; - break; - case Instruction::And: - highBitSet = true; - break; - } - - // If this is a signed shift right, and the high bit is modified - // by the logical operation, do not perform the transformation. - // The highBitSet boolean indicates the value of the high bit of - // the constant which would cause it to be modified for this - // operation. 
- // - if (isValid && I.getOpcode() == Instruction::AShr) - isValid = Op0C->getValue()[TypeBits-1] == highBitSet; - - if (isValid) { - Constant *NewRHS = ConstantExpr::get(I.getOpcode(), Op0C, Op1); + const APInt *Op0C; + if (match(Op0BO->getOperand(1), m_APInt(Op0C))) { + if (canShiftBinOpWithConstantRHS(I, Op0BO, *Op0C)) { + Constant *NewRHS = ConstantExpr::get(I.getOpcode(), + cast<Constant>(Op0BO->getOperand(1)), Op1); Value *NewShift = Builder.CreateBinOp(I.getOpcode(), Op0BO->getOperand(0), Op1); @@ -508,6 +518,67 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, NewRHS); } } + + // If the operand is a subtract with a constant LHS, and the shift + // is the only use, we can pull it out of the shift. + // This folds (shl (sub C1, X), C2) -> (sub (C1 << C2), (shl X, C2)) + if (isLeftShift && Op0BO->getOpcode() == Instruction::Sub && + match(Op0BO->getOperand(0), m_APInt(Op0C))) { + Constant *NewRHS = ConstantExpr::get(I.getOpcode(), + cast<Constant>(Op0BO->getOperand(0)), Op1); + + Value *NewShift = Builder.CreateShl(Op0BO->getOperand(1), Op1); + NewShift->takeName(Op0BO); + + return BinaryOperator::CreateSub(NewRHS, NewShift); + } + } + + // If we have a select that conditionally executes some binary operator, + // see if we can pull it the select and operator through the shift. + // + // For example, turning: + // shl (select C, (add X, C1), X), C2 + // Into: + // Y = shl X, C2 + // select C, (add Y, C1 << C2), Y + Value *Cond; + BinaryOperator *TBO; + Value *FalseVal; + if (match(Op0, m_Select(m_Value(Cond), m_OneUse(m_BinOp(TBO)), + m_Value(FalseVal)))) { + const APInt *C; + if (!isa<Constant>(FalseVal) && TBO->getOperand(0) == FalseVal && + match(TBO->getOperand(1), m_APInt(C)) && + canShiftBinOpWithConstantRHS(I, TBO, *C)) { + Constant *NewRHS = ConstantExpr::get(I.getOpcode(), + cast<Constant>(TBO->getOperand(1)), Op1); + + Value *NewShift = + Builder.CreateBinOp(I.getOpcode(), FalseVal, Op1); + Value *NewOp = Builder.CreateBinOp(TBO->getOpcode(), NewShift, + NewRHS); + return SelectInst::Create(Cond, NewOp, NewShift); + } + } + + BinaryOperator *FBO; + Value *TrueVal; + if (match(Op0, m_Select(m_Value(Cond), m_Value(TrueVal), + m_OneUse(m_BinOp(FBO))))) { + const APInt *C; + if (!isa<Constant>(TrueVal) && FBO->getOperand(0) == TrueVal && + match(FBO->getOperand(1), m_APInt(C)) && + canShiftBinOpWithConstantRHS(I, FBO, *C)) { + Constant *NewRHS = ConstantExpr::get(I.getOpcode(), + cast<Constant>(FBO->getOperand(1)), Op1); + + Value *NewShift = + Builder.CreateBinOp(I.getOpcode(), TrueVal, Op1); + Value *NewOp = Builder.CreateBinOp(FBO->getOpcode(), NewShift, + NewRHS); + return SelectInst::Create(Cond, NewShift, NewOp); + } } } @@ -543,8 +614,8 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { return new ZExtInst(Builder.CreateShl(X, ShAmt), Ty); } - // (X >>u C) << C --> X & (-1 << C) - if (match(Op0, m_LShr(m_Value(X), m_Specific(Op1)))) { + // (X >> C) << C --> X & (-1 << C) + if (match(Op0, m_Shr(m_Value(X), m_Specific(Op1)))) { APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmt)); return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask)); } @@ -680,6 +751,15 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) { return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask)); } + if (match(Op0, m_OneUse(m_ZExt(m_Value(X)))) && + (!Ty->isIntegerTy() || shouldChangeType(Ty, X->getType()))) { + assert(ShAmt < X->getType()->getScalarSizeInBits() && + "Big shift not simplified to zero?"); + // lshr (zext iM X to iN), C 
--> zext (lshr X, C) to iN + Value *NewLShr = Builder.CreateLShr(X, ShAmt); + return new ZExtInst(NewLShr, Ty); + } + if (match(Op0, m_SExt(m_Value(X))) && (!Ty->isIntegerTy() || shouldChangeType(Ty, X->getType()))) { // Are we moving the sign bit to the low bit and widening with high zeros? @@ -778,6 +858,15 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) { return BinaryOperator::CreateAShr(X, ConstantInt::get(Ty, AmtSum)); } + if (match(Op0, m_OneUse(m_SExt(m_Value(X)))) && + (Ty->isVectorTy() || shouldChangeType(Ty, X->getType()))) { + // ashr (sext X), C --> sext (ashr X, C') + Type *SrcTy = X->getType(); + ShAmt = std::min(ShAmt, SrcTy->getScalarSizeInBits() - 1); + Value *NewSh = Builder.CreateAShr(X, ConstantInt::get(SrcTy, ShAmt)); + return new SExtInst(NewSh, Ty); + } + // If the shifted-out value is known-zero, then this is an exact shift. if (!I.isExact() && MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) { diff --git a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index a20f474cbf40..a2e757cb4273 100644 --- a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -396,50 +396,50 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, /// If the high-bits of an ADD/SUB are not demanded, then we do not care /// about the high bits of the operands. unsigned NLZ = DemandedMask.countLeadingZeros(); - if (NLZ > 0) { - // Right fill the mask of bits for this ADD/SUB to demand the most - // significant bit and all those below it. - APInt DemandedFromOps(APInt::getLowBitsSet(BitWidth, BitWidth-NLZ)); - if (ShrinkDemandedConstant(I, 0, DemandedFromOps) || - SimplifyDemandedBits(I, 0, DemandedFromOps, LHSKnown, Depth + 1) || - ShrinkDemandedConstant(I, 1, DemandedFromOps) || - SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1)) { + // Right fill the mask of bits for this ADD/SUB to demand the most + // significant bit and all those below it. + APInt DemandedFromOps(APInt::getLowBitsSet(BitWidth, BitWidth-NLZ)); + if (ShrinkDemandedConstant(I, 0, DemandedFromOps) || + SimplifyDemandedBits(I, 0, DemandedFromOps, LHSKnown, Depth + 1) || + ShrinkDemandedConstant(I, 1, DemandedFromOps) || + SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1)) { + if (NLZ > 0) { // Disable the nsw and nuw flags here: We can no longer guarantee that // we won't wrap after simplification. Removing the nsw/nuw flags is // legal here because the top bit is not demanded. BinaryOperator &BinOP = *cast<BinaryOperator>(I); BinOP.setHasNoSignedWrap(false); BinOP.setHasNoUnsignedWrap(false); - return I; } - - // If we are known to be adding/subtracting zeros to every bit below - // the highest demanded bit, we just return the other side. - if (DemandedFromOps.isSubsetOf(RHSKnown.Zero)) - return I->getOperand(0); - // We can't do this with the LHS for subtraction, unless we are only - // demanding the LSB. - if ((I->getOpcode() == Instruction::Add || - DemandedFromOps.isOneValue()) && - DemandedFromOps.isSubsetOf(LHSKnown.Zero)) - return I->getOperand(1); + return I; } - // Otherwise just hand the add/sub off to computeKnownBits to fill in - // the known zeros and ones. - computeKnownBits(V, Known, Depth, CxtI); + // If we are known to be adding/subtracting zeros to every bit below + // the highest demanded bit, we just return the other side. 
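  // e.g. (sketch): if only the low 8 bits of (add %x, %y) are demanded and the
  // low 8 bits of %y are known to be zero, the add contributes nothing in the
  // demanded range and %x can be returned directly.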
+ if (DemandedFromOps.isSubsetOf(RHSKnown.Zero)) + return I->getOperand(0); + // We can't do this with the LHS for subtraction, unless we are only + // demanding the LSB. + if ((I->getOpcode() == Instruction::Add || + DemandedFromOps.isOneValue()) && + DemandedFromOps.isSubsetOf(LHSKnown.Zero)) + return I->getOperand(1); + + // Otherwise just compute the known bits of the result. + bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); + Known = KnownBits::computeForAddSub(I->getOpcode() == Instruction::Add, + NSW, LHSKnown, RHSKnown); break; } case Instruction::Shl: { const APInt *SA; if (match(I->getOperand(1), m_APInt(SA))) { const APInt *ShrAmt; - if (match(I->getOperand(0), m_Shr(m_Value(), m_APInt(ShrAmt)))) { - Instruction *Shr = cast<Instruction>(I->getOperand(0)); - if (Value *R = simplifyShrShlDemandedBits( - Shr, *ShrAmt, I, *SA, DemandedMask, Known)) - return R; - } + if (match(I->getOperand(0), m_Shr(m_Value(), m_APInt(ShrAmt)))) + if (Instruction *Shr = dyn_cast<Instruction>(I->getOperand(0))) + if (Value *R = simplifyShrShlDemandedBits(Shr, *ShrAmt, I, *SA, + DemandedMask, Known)) + return R; uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1); APInt DemandedMaskIn(DemandedMask.lshr(ShiftAmt)); @@ -521,26 +521,25 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1)) return I; + unsigned SignBits = ComputeNumSignBits(I->getOperand(0), Depth + 1, CxtI); + assert(!Known.hasConflict() && "Bits known to be one AND zero?"); - // Compute the new bits that are at the top now. - APInt HighBits(APInt::getHighBitsSet(BitWidth, ShiftAmt)); + // Compute the new bits that are at the top now plus sign bits. + APInt HighBits(APInt::getHighBitsSet( + BitWidth, std::min(SignBits + ShiftAmt - 1, BitWidth))); Known.Zero.lshrInPlace(ShiftAmt); Known.One.lshrInPlace(ShiftAmt); - // Handle the sign bits. - APInt SignMask(APInt::getSignMask(BitWidth)); - // Adjust to where it is now in the mask. - SignMask.lshrInPlace(ShiftAmt); - // If the input sign bit is known to be zero, or if none of the top bits // are demanded, turn this into an unsigned shift right. - if (BitWidth <= ShiftAmt || Known.Zero[BitWidth-ShiftAmt-1] || + assert(BitWidth > ShiftAmt && "Shift amount not saturated?"); + if (Known.Zero[BitWidth-ShiftAmt-1] || !DemandedMask.intersects(HighBits)) { BinaryOperator *LShr = BinaryOperator::CreateLShr(I->getOperand(0), I->getOperand(1)); LShr->setIsExact(cast<BinaryOperator>(I)->isExact()); return InsertNewInstWith(LShr, *I); - } else if (Known.One.intersects(SignMask)) { // New bits are known one. + } else if (Known.One[BitWidth-ShiftAmt-1]) { // New bits are known one. Known.One |= HighBits; } } @@ -993,22 +992,23 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, break; } + // The element inserted overwrites whatever was there, so the input demanded + // set is simpler than the output set. + unsigned IdxNo = Idx->getZExtValue(); + APInt PreInsertDemandedElts = DemandedElts; + if (IdxNo < VWidth) + PreInsertDemandedElts.clearBit(IdxNo); + TmpV = SimplifyDemandedVectorElts(I->getOperand(0), PreInsertDemandedElts, + UndefElts, Depth + 1); + if (TmpV) { I->setOperand(0, TmpV); MadeChange = true; } + // If this is inserting an element that isn't demanded, remove this // insertelement. 
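  // e.g. (illustrative): when lane 2 of the result is never read,
  //   insertelement <4 x i32> %v, i32 %s, i32 2
  // can be replaced by %v outright.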
- unsigned IdxNo = Idx->getZExtValue(); if (IdxNo >= VWidth || !DemandedElts[IdxNo]) { Worklist.Add(I); return I->getOperand(0); } - // Otherwise, the element inserted overwrites whatever was there, so the - // input demanded set is simpler than the output set. - APInt DemandedElts2 = DemandedElts; - DemandedElts2.clearBit(IdxNo); - TmpV = SimplifyDemandedVectorElts(I->getOperand(0), DemandedElts2, - UndefElts, Depth + 1); - if (TmpV) { I->setOperand(0, TmpV); MadeChange = true; } - // The inserted element is defined. UndefElts.clearBit(IdxNo); break; diff --git a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index dd71a31b644b..65a96b965227 100644 --- a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -13,10 +13,33 @@ //===----------------------------------------------------------------------===// #include "InstCombineInternal.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/VectorUtils.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Transforms/InstCombine/InstCombineWorklist.h" +#include <cassert> +#include <cstdint> +#include <iterator> +#include <utility> + using namespace llvm; using namespace PatternMatch; @@ -90,7 +113,7 @@ Instruction *InstCombiner::scalarizePHI(ExtractElementInst &EI, PHINode *PN) { // Verify that this PHI user has one use, which is the PHI itself, // and that it is a binary operation which is cheap to scalarize. - // otherwise return NULL. + // otherwise return nullptr. if (!PHIUser->hasOneUse() || !(PHIUser->user_back() == PN) || !(isa<BinaryOperator>(PHIUser)) || !cheapToScalarize(PHIUser, true)) return nullptr; @@ -255,39 +278,6 @@ Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) { Worklist.AddValue(EE); return CastInst::Create(CI->getOpcode(), EE, EI.getType()); } - } else if (SelectInst *SI = dyn_cast<SelectInst>(I)) { - if (SI->hasOneUse()) { - // TODO: For a select on vectors, it might be useful to do this if it - // has multiple extractelement uses. For vector select, that seems to - // fight the vectorizer. - - // If we are extracting an element from a vector select or a select on - // vectors, create a select on the scalars extracted from the vector - // arguments. 
- Value *TrueVal = SI->getTrueValue(); - Value *FalseVal = SI->getFalseValue(); - - Value *Cond = SI->getCondition(); - if (Cond->getType()->isVectorTy()) { - Cond = Builder.CreateExtractElement(Cond, - EI.getIndexOperand(), - Cond->getName() + ".elt"); - } - - Value *V1Elem - = Builder.CreateExtractElement(TrueVal, - EI.getIndexOperand(), - TrueVal->getName() + ".elt"); - - Value *V2Elem - = Builder.CreateExtractElement(FalseVal, - EI.getIndexOperand(), - FalseVal->getName() + ".elt"); - return SelectInst::Create(Cond, - V1Elem, - V2Elem, - SI->getName() + ".elt"); - } } } return nullptr; @@ -454,7 +444,7 @@ static void replaceExtractElements(InsertElementInst *InsElt, /// /// Note: we intentionally don't try to fold earlier shuffles since they have /// often been chosen carefully to be efficiently implementable on the target. -typedef std::pair<Value *, Value *> ShuffleOps; +using ShuffleOps = std::pair<Value *, Value *>; static ShuffleOps collectShuffleElements(Value *V, SmallVectorImpl<Constant *> &Mask, @@ -615,20 +605,26 @@ static Instruction *foldInsSequenceIntoBroadcast(InsertElementInst &InsElt) { Value *SplatVal = InsElt.getOperand(1); InsertElementInst *CurrIE = &InsElt; SmallVector<bool, 16> ElementPresent(NumElements, false); + InsertElementInst *FirstIE = nullptr; // Walk the chain backwards, keeping track of which indices we inserted into, // until we hit something that isn't an insert of the splatted value. while (CurrIE) { - ConstantInt *Idx = dyn_cast<ConstantInt>(CurrIE->getOperand(2)); + auto *Idx = dyn_cast<ConstantInt>(CurrIE->getOperand(2)); if (!Idx || CurrIE->getOperand(1) != SplatVal) return nullptr; - // Check none of the intermediate steps have any additional uses. - if ((CurrIE != &InsElt) && !CurrIE->hasOneUse()) + auto *NextIE = dyn_cast<InsertElementInst>(CurrIE->getOperand(0)); + // Check none of the intermediate steps have any additional uses, except + // for the root insertelement instruction, which can be re-used, if it + // inserts at position 0. + if (CurrIE != &InsElt && + (!CurrIE->hasOneUse() && (NextIE != nullptr || !Idx->isZero()))) return nullptr; ElementPresent[Idx->getZExtValue()] = true; - CurrIE = dyn_cast<InsertElementInst>(CurrIE->getOperand(0)); + FirstIE = CurrIE; + CurrIE = NextIE; } // Make sure we've seen an insert into every element. @@ -636,9 +632,14 @@ static Instruction *foldInsSequenceIntoBroadcast(InsertElementInst &InsElt) { return nullptr; // All right, create the insert + shuffle. - Instruction *InsertFirst = InsertElementInst::Create( - UndefValue::get(VT), SplatVal, - ConstantInt::get(Type::getInt32Ty(InsElt.getContext()), 0), "", &InsElt); + Instruction *InsertFirst; + if (cast<ConstantInt>(FirstIE->getOperand(2))->isZero()) + InsertFirst = FirstIE; + else + InsertFirst = InsertElementInst::Create( + UndefValue::get(VT), SplatVal, + ConstantInt::get(Type::getInt32Ty(InsElt.getContext()), 0), + "", &InsElt); Constant *ZeroMask = ConstantAggregateZero::get( VectorType::get(Type::getInt32Ty(InsElt.getContext()), NumElements)); @@ -780,6 +781,10 @@ Instruction *InstCombiner::visitInsertElementInst(InsertElementInst &IE) { Value *ScalarOp = IE.getOperand(1); Value *IdxOp = IE.getOperand(2); + if (auto *V = SimplifyInsertElementInst( + VecOp, ScalarOp, IdxOp, SQ.getWithInstruction(&IE))) + return replaceInstUsesWith(IE, V); + // Inserting an undef or into an undefined place, remove this. 
if (isa<UndefValue>(ScalarOp) || isa<UndefValue>(IdxOp)) replaceInstUsesWith(IE, VecOp); @@ -1007,15 +1012,13 @@ InstCombiner::EvaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask) { // Mask.size() does not need to be equal to the number of vector elements. assert(V->getType()->isVectorTy() && "can't reorder non-vector elements"); - if (isa<UndefValue>(V)) { - return UndefValue::get(VectorType::get(V->getType()->getScalarType(), - Mask.size())); - } - if (isa<ConstantAggregateZero>(V)) { - return ConstantAggregateZero::get( - VectorType::get(V->getType()->getScalarType(), - Mask.size())); - } + Type *EltTy = V->getType()->getScalarType(); + if (isa<UndefValue>(V)) + return UndefValue::get(VectorType::get(EltTy, Mask.size())); + + if (isa<ConstantAggregateZero>(V)) + return ConstantAggregateZero::get(VectorType::get(EltTy, Mask.size())); + if (Constant *C = dyn_cast<Constant>(V)) { SmallVector<Constant *, 16> MaskValues; for (int i = 0, e = Mask.size(); i != e; ++i) { @@ -1153,9 +1156,7 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { if (Value *V = SimplifyDemandedVectorElts(&SVI, AllOnesEltMask, UndefElts)) { if (V != &SVI) return replaceInstUsesWith(SVI, V); - LHS = SVI.getOperand(0); - RHS = SVI.getOperand(1); - MadeChange = true; + return &SVI; } unsigned LHSWidth = LHS->getType()->getVectorNumElements(); @@ -1446,7 +1447,7 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { eltMask = Mask[i]-LHSWidth; // If LHS's width is changed, shift the mask value accordingly. - // If newRHS == NULL, i.e. LHSOp0 == RHSOp0, we want to remap any + // If newRHS == nullptr, i.e. LHSOp0 == RHSOp0, we want to remap any // references from RHSOp0 to LHSOp0, so we don't need to shift the mask. // If newRHS == newLHS, we want to remap any references from newRHS to // newLHS so that we can properly identify splats that may occur due to diff --git a/lib/Transforms/InstCombine/InstructionCombining.cpp b/lib/Transforms/InstCombine/InstructionCombining.cpp index c7766568fd9d..b332e75c7feb 100644 --- a/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -34,10 +34,14 @@ //===----------------------------------------------------------------------===// #include "InstCombineInternal.h" -#include "llvm-c/Initialization.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/None.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/ADT/StringSwitch.h" +#include "llvm/ADT/TinyPtrVector.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BasicAliasAnalysis.h" @@ -48,24 +52,56 @@ #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/TargetFolder.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DIBuilder.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" #include "llvm/IR/GetElementPtrTypeIterator.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include 
"llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/PassManager.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" #include "llvm/IR/ValueHandle.h" +#include "llvm/Pass.h" +#include "llvm/Support/CBindingWrapping.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugCounter.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/KnownBits.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/InstCombine/InstCombine.h" +#include "llvm/Transforms/InstCombine/InstCombineWorklist.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" #include <algorithm> -#include <climits> +#include <cassert> +#include <cstdint> +#include <memory> +#include <string> +#include <utility> + using namespace llvm; using namespace llvm::PatternMatch; @@ -78,6 +114,8 @@ STATISTIC(NumSunkInst , "Number of instructions sunk"); STATISTIC(NumExpand, "Number of expansions"); STATISTIC(NumFactor , "Number of factorizations"); STATISTIC(NumReassoc , "Number of reassociations"); +DEBUG_COUNTER(VisitCounter, "instcombine-visit", + "Controls which instructions are visited"); static cl::opt<bool> EnableExpensiveCombines("expensive-combines", @@ -87,6 +125,16 @@ static cl::opt<unsigned> MaxArraySize("instcombine-maxarray-size", cl::init(1024), cl::desc("Maximum array size considered when doing a combine")); +// FIXME: Remove this flag when it is no longer necessary to convert +// llvm.dbg.declare to avoid inaccurate debug info. Setting this to false +// increases variable availability at the cost of accuracy. Variables that +// cannot be promoted by mem2reg or SROA will be described as living in memory +// for their entire lifetime. However, passes like DSE and instcombine can +// delete stores to the alloca, leading to misleading and inaccurate debug +// information. This flag can be removed when those passes are fixed. +static cl::opt<unsigned> ShouldLowerDbgDeclare("instcombine-lower-dbg-declare", + cl::Hidden, cl::init(true)); + Value *InstCombiner::EmitGEPOffset(User *GEP) { return llvm::EmitGEPOffset(&Builder, DL, GEP); } @@ -381,7 +429,7 @@ bool InstCombiner::SimplifyAssociativeOrCommutative(BinaryOperator &I) { // No further simplifications. 
return Changed; - } while (1); + } while (true); } /// Return whether "X LOp (Y ROp Z)" is always equal to @@ -704,36 +752,37 @@ Value *InstCombiner::SimplifyUsingDistributiveLaws(BinaryOperator &I) { } } - // (op (select (a, c, b)), (select (a, d, b))) -> (select (a, (op c, d), 0)) - // (op (select (a, b, c)), (select (a, b, d))) -> (select (a, 0, (op c, d))) - if (auto *SI0 = dyn_cast<SelectInst>(LHS)) { - if (auto *SI1 = dyn_cast<SelectInst>(RHS)) { - if (SI0->getCondition() == SI1->getCondition()) { - Value *SI = nullptr; - if (Value *V = - SimplifyBinOp(TopLevelOpcode, SI0->getFalseValue(), - SI1->getFalseValue(), SQ.getWithInstruction(&I))) - SI = Builder.CreateSelect(SI0->getCondition(), - Builder.CreateBinOp(TopLevelOpcode, - SI0->getTrueValue(), - SI1->getTrueValue()), - V); - if (Value *V = - SimplifyBinOp(TopLevelOpcode, SI0->getTrueValue(), - SI1->getTrueValue(), SQ.getWithInstruction(&I))) - SI = Builder.CreateSelect( - SI0->getCondition(), V, - Builder.CreateBinOp(TopLevelOpcode, SI0->getFalseValue(), - SI1->getFalseValue())); - if (SI) { - SI->takeName(&I); - return SI; - } - } - } + return SimplifySelectsFeedingBinaryOp(I, LHS, RHS); +} + +Value *InstCombiner::SimplifySelectsFeedingBinaryOp(BinaryOperator &I, + Value *LHS, Value *RHS) { + Instruction::BinaryOps Opcode = I.getOpcode(); + // (op (select (a, b, c)), (select (a, d, e))) -> (select (a, (op b, d), (op + // c, e))) + Value *A, *B, *C, *D, *E; + Value *SI = nullptr; + if (match(LHS, m_Select(m_Value(A), m_Value(B), m_Value(C))) && + match(RHS, m_Select(m_Specific(A), m_Value(D), m_Value(E)))) { + bool SelectsHaveOneUse = LHS->hasOneUse() && RHS->hasOneUse(); + BuilderTy::FastMathFlagGuard Guard(Builder); + if (isa<FPMathOperator>(&I)) + Builder.setFastMathFlags(I.getFastMathFlags()); + + Value *V1 = SimplifyBinOp(Opcode, C, E, SQ.getWithInstruction(&I)); + Value *V2 = SimplifyBinOp(Opcode, B, D, SQ.getWithInstruction(&I)); + if (V1 && V2) + SI = Builder.CreateSelect(A, V2, V1); + else if (V2 && SelectsHaveOneUse) + SI = Builder.CreateSelect(A, V2, Builder.CreateBinOp(Opcode, C, E)); + else if (V1 && SelectsHaveOneUse) + SI = Builder.CreateSelect(A, Builder.CreateBinOp(Opcode, B, D), V1); + + if (SI) + SI->takeName(&I); } - return nullptr; + return SI; } /// Given a 'sub' instruction, return the RHS of the instruction if the LHS is a @@ -1158,7 +1207,7 @@ Value *InstCombiner::Descale(Value *Val, APInt Scale, bool &NoSignedWrap) { // Parent - initially null, but after drilling down notes where Op came from. // In the example above, Parent is (Val, 0) when Op is M1, because M1 is the // 0'th operand of Val. - std::pair<Instruction*, unsigned> Parent; + std::pair<Instruction *, unsigned> Parent; // Set if the transform requires a descaling at deeper levels that doesn't // overflow. @@ -1168,7 +1217,6 @@ Value *InstCombiner::Descale(Value *Val, APInt Scale, bool &NoSignedWrap) { int32_t logScale = Scale.exactLogBase2(); for (;; Op = Parent.first->getOperand(Parent.second)) { // Drill down - if (ConstantInt *CI = dyn_cast<ConstantInt>(Op)) { // If Op is a constant divisible by Scale then descale to the quotient. APInt Quotient(Scale), Remainder(Scale); // Init ensures right bitwidth. @@ -1183,7 +1231,6 @@ Value *InstCombiner::Descale(Value *Val, APInt Scale, bool &NoSignedWrap) { } if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Op)) { - if (BO->getOpcode() == Instruction::Mul) { // Multiplication. 
NoSignedWrap = BO->hasNoSignedWrap(); @@ -1358,7 +1405,7 @@ Value *InstCombiner::Descale(Value *Val, APInt Scale, bool &NoSignedWrap) { // Move up one level in the expression. assert(Ancestor->hasOneUse() && "Drilled down when more than one use!"); Ancestor = Ancestor->user_back(); - } while (1); + } while (true); } /// \brief Creates node of binary operation with the same attributes as the @@ -1605,7 +1652,6 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // Combine Indices - If the source pointer to this getelementptr instruction // is a getelementptr instruction, combine the indices of the two // getelementptr instructions into a single instruction. - // if (GEPOperator *Src = dyn_cast<GEPOperator>(PtrOp)) { if (!shouldMergeGEPs(*cast<GEPOperator>(&GEP), *Src)) return nullptr; @@ -1630,7 +1676,6 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (EndsWithSequential) { // Replace: gep (gep %P, long B), long A, ... // With: T = long A+B; gep %P, T, ... - // Value *SO1 = Src->getOperand(Src->getNumOperands()-1); Value *GO1 = GEP.getOperand(1); @@ -2053,8 +2098,6 @@ static bool isAllocSiteRemovable(Instruction *AI, return false; LLVM_FALLTHROUGH; } - case Intrinsic::dbg_declare: - case Intrinsic::dbg_value: case Intrinsic::invariant_start: case Intrinsic::invariant_end: case Intrinsic::lifetime_start: @@ -2090,6 +2133,16 @@ Instruction *InstCombiner::visitAllocSite(Instruction &MI) { // to null and free calls, delete the calls and replace the comparisons with // true or false as appropriate. SmallVector<WeakTrackingVH, 64> Users; + + // If we are removing an alloca with a dbg.declare, insert dbg.value calls + // before each store. + TinyPtrVector<DbgInfoIntrinsic *> DIIs; + std::unique_ptr<DIBuilder> DIB; + if (isa<AllocaInst>(MI)) { + DIIs = FindDbgAddrUses(&MI); + DIB.reset(new DIBuilder(*MI.getModule(), /*AllowUnresolved=*/false)); + } + if (isAllocSiteRemovable(&MI, Users, &TLI)) { for (unsigned i = 0, e = Users.size(); i != e; ++i) { // Lowering all @llvm.objectsize calls first because they may @@ -2122,6 +2175,9 @@ Instruction *InstCombiner::visitAllocSite(Instruction &MI) { } else if (isa<BitCastInst>(I) || isa<GetElementPtrInst>(I) || isa<AddrSpaceCastInst>(I)) { replaceInstUsesWith(*I, UndefValue::get(I->getType())); + } else if (auto *SI = dyn_cast<StoreInst>(I)) { + for (auto *DII : DIIs) + ConvertDebugDeclareToDebugValue(DII, SI, *DIB); } eraseInstFromFunction(*I); } @@ -2133,6 +2189,10 @@ Instruction *InstCombiner::visitAllocSite(Instruction &MI) { InvokeInst::Create(F, II->getNormalDest(), II->getUnwindDest(), None, "", II->getParent()); } + + for (auto *DII : DIIs) + eraseInstFromFunction(*DII); + return eraseInstFromFunction(MI); } return nullptr; @@ -2195,7 +2255,6 @@ tryToMoveFreeBeforeNullTest(CallInst &FI) { return &FI; } - Instruction *InstCombiner::visitFree(CallInst &FI) { Value *Op = FI.getArgOperand(0); @@ -2258,10 +2317,9 @@ Instruction *InstCombiner::visitBranchInst(BranchInst &BI) { // If the condition is irrelevant, remove the use so that other // transforms on the condition become more effective. 
- if (BI.isConditional() && - BI.getSuccessor(0) == BI.getSuccessor(1) && - !isa<UndefValue>(BI.getCondition())) { - BI.setCondition(UndefValue::get(BI.getCondition()->getType())); + if (BI.isConditional() && !isa<ConstantInt>(BI.getCondition()) && + BI.getSuccessor(0) == BI.getSuccessor(1)) { + BI.setCondition(ConstantInt::getFalse(BI.getCondition()->getType())); return &BI; } @@ -2881,6 +2939,9 @@ bool InstCombiner::run() { continue; } + if (!DebugCounter::shouldExecute(VisitCounter)) + continue; + // Instruction isn't dead, see if we can constant propagate it. if (!I->use_empty() && (I->getNumOperands() == 0 || isa<Constant>(I->getOperand(0)))) { @@ -3027,7 +3088,6 @@ bool InstCombiner::run() { /// them to the worklist (this significantly speeds up instcombine on code where /// many instructions are dead or constant). Additionally, if we find a branch /// whose condition is a known constant, we only visit the reachable successors. -/// static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL, SmallPtrSetImpl<BasicBlock *> &Visited, InstCombineWorklist &ICWorklist, @@ -3053,6 +3113,7 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL, if (isInstructionTriviallyDead(Inst, TLI)) { ++NumDeadInst; DEBUG(dbgs() << "IC: DCE: " << *Inst << '\n'); + salvageDebugInfo(*Inst); Inst->eraseFromParent(); MadeIRChange = true; continue; @@ -3162,12 +3223,11 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL, return MadeIRChange; } -static bool -combineInstructionsOverFunction(Function &F, InstCombineWorklist &Worklist, - AliasAnalysis *AA, AssumptionCache &AC, - TargetLibraryInfo &TLI, DominatorTree &DT, - bool ExpensiveCombines = true, - LoopInfo *LI = nullptr) { +static bool combineInstructionsOverFunction( + Function &F, InstCombineWorklist &Worklist, AliasAnalysis *AA, + AssumptionCache &AC, TargetLibraryInfo &TLI, DominatorTree &DT, + OptimizationRemarkEmitter &ORE, bool ExpensiveCombines = true, + LoopInfo *LI = nullptr) { auto &DL = F.getParent()->getDataLayout(); ExpensiveCombines |= EnableExpensiveCombines; @@ -3177,27 +3237,27 @@ combineInstructionsOverFunction(Function &F, InstCombineWorklist &Worklist, F.getContext(), TargetFolder(DL), IRBuilderCallbackInserter([&Worklist, &AC](Instruction *I) { Worklist.Add(I); - - using namespace llvm::PatternMatch; if (match(I, m_Intrinsic<Intrinsic::assume>())) AC.registerAssumption(cast<CallInst>(I)); })); // Lower dbg.declare intrinsics otherwise their value may be clobbered // by instcombiner. - bool MadeIRChange = LowerDbgDeclare(F); + bool MadeIRChange = false; + if (ShouldLowerDbgDeclare) + MadeIRChange = LowerDbgDeclare(F); // Iterate while there is work to do. 
int Iteration = 0; - for (;;) { + while (true) { ++Iteration; DEBUG(dbgs() << "\n\nINSTCOMBINE ITERATION #" << Iteration << " on " << F.getName() << "\n"); MadeIRChange |= prepareICWorklistFromFunction(F, DL, &TLI, Worklist); - InstCombiner IC(Worklist, Builder, F.optForMinSize(), ExpensiveCombines, - AA, AC, TLI, DT, DL, LI); + InstCombiner IC(Worklist, Builder, F.optForMinSize(), ExpensiveCombines, AA, + AC, TLI, DT, ORE, DL, LI); IC.MaxArraySizeForCombine = MaxArraySize; if (!IC.run()) @@ -3212,11 +3272,12 @@ PreservedAnalyses InstCombinePass::run(Function &F, auto &AC = AM.getResult<AssumptionAnalysis>(F); auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); + auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); auto *LI = AM.getCachedResult<LoopAnalysis>(F); - // FIXME: The AliasAnalysis is not yet supported in the new pass manager - if (!combineInstructionsOverFunction(F, Worklist, nullptr, AC, TLI, DT, + auto *AA = &AM.getResult<AAManager>(F); + if (!combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, DT, ORE, ExpensiveCombines, LI)) // No changes, all analyses are preserved. return PreservedAnalyses::all(); @@ -3225,6 +3286,7 @@ PreservedAnalyses InstCombinePass::run(Function &F, PreservedAnalyses PA; PA.preserveSet<CFGAnalyses>(); PA.preserve<AAManager>(); + PA.preserve<BasicAA>(); PA.preserve<GlobalsAA>(); return PA; } @@ -3235,6 +3297,7 @@ void InstructionCombiningPass::getAnalysisUsage(AnalysisUsage &AU) const { AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<TargetLibraryInfoWrapperPass>(); AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); AU.addPreserved<DominatorTreeWrapperPass>(); AU.addPreserved<AAResultsWrapperPass>(); AU.addPreserved<BasicAAWrapperPass>(); @@ -3250,16 +3313,18 @@ bool InstructionCombiningPass::runOnFunction(Function &F) { auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); // Optional analyses. auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>(); auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr; - return combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, DT, + return combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, DT, ORE, ExpensiveCombines, LI); } char InstructionCombiningPass::ID = 0; + INITIALIZE_PASS_BEGIN(InstructionCombiningPass, "instcombine", "Combine redundant instructions", false, false) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) @@ -3267,6 +3332,7 @@ INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) INITIALIZE_PASS_END(InstructionCombiningPass, "instcombine", "Combine redundant instructions", false, false) |
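
A note on the new ashr fold added to visitAShr near the top of this diff: `ashr (sext X), C --> sext (ashr X, C')` is only sound because the shift amount is clamped to `SrcTy->getScalarSizeInBits() - 1` before the narrow shift is built. The following standalone C++ sketch is my own check, not part of the patch; it exhaustively verifies the scalar identity for an i8 source widened to i32, and it assumes the usual arithmetic behaviour of `>>` on negative signed values (implementation-defined before C++20, but universal on LLVM's supported hosts).

// Exhaustive check of ashr(sext X, C) == sext(ashr X, min(C, srcBits - 1))
// for an i8 source sign-extended to i32.
#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
  for (int x = -128; x <= 127; ++x) {
    for (unsigned s = 0; s < 32; ++s) {
      int32_t wide = static_cast<int8_t>(x);   // sext i8 X to i32
      int32_t lhs = wide >> s;                 // ashr i32 (sext X), s
      unsigned clamped = std::min(s, 7u);      // C' = min(C, bits(i8) - 1)
      int8_t narrow =
          static_cast<int8_t>(static_cast<int8_t>(x) >> clamped); // ashr i8 X, C'
      int32_t rhs = narrow;                    // sext (ashr i8 X, C') to i32
      assert(lhs == rhs && "ashr(sext X, C) != sext(ashr X, C')");
    }
  }
  return 0;
}

The clamp matters because a shift amount such as 20 is legal on the wide i32 value but would be out of range on the original i8; min(C, 7) yields the same result, since every bit past the seventh of the sign-extended value is already a copy of the sign bit.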
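
Similarly, the new SimplifySelectsFeedingBinaryOp helper in InstructionCombining.cpp relies on the identity `(op (select a, b, c), (select a, d, e)) == (select a, (op b, d), (op c, e))`: when both selects share a condition, the operation can be applied arm by arm. A minimal C++ check of the scalar identity, using addition as the operation (the value set and loop bounds below are arbitrary choices of mine, not taken from the patch):

// Scalar check of the select-distribution identity: because both selects
// pick the same arm, combining the chosen operands first gives the same
// result as combining per-arm and then selecting.
#include <cassert>
#include <cstdint>
#include <initializer_list>

int main() {
  const int32_t vals[] = {-7, -1, 0, 1, 3, 42};
  for (bool a : {false, true})
    for (int32_t b : vals)
      for (int32_t c : vals)
        for (int32_t d : vals)
          for (int32_t e : vals) {
            int32_t original = (a ? b : c) + (a ? d : e); // op of two selects
            int32_t folded = a ? (b + d) : (c + e);       // select of two ops
            assert(folded == original);
          }
  return 0;
}

In the patch itself the helper only materialises the new select when both arms simplify, or when one arm simplifies and both original selects have a single use, so the fold never increases the instruction count.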