diff options
Diffstat (limited to 'lib/Analysis/InstructionSimplify.cpp')
-rw-r--r-- | lib/Analysis/InstructionSimplify.cpp | 680 |
1 files changed, 387 insertions, 293 deletions
diff --git a/lib/Analysis/InstructionSimplify.cpp b/lib/Analysis/InstructionSimplify.cpp index 796e6e444980..e12f640394e6 100644 --- a/lib/Analysis/InstructionSimplify.cpp +++ b/lib/Analysis/InstructionSimplify.cpp @@ -24,6 +24,7 @@ #include "llvm/Analysis/CaptureTracking.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/OptimizationDiagnosticInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/ConstantRange.h" @@ -140,10 +141,9 @@ static bool ValueDominatesPHI(Value *V, PHINode *P, const DominatorTree *DT) { /// given by OpcodeToExpand, while "A" corresponds to LHS and "B op' C" to RHS. /// Also performs the transform "(A op' B) op C" -> "(A op C) op' (B op C)". /// Returns the simplified value, or null if no simplification was performed. -static Value *ExpandBinOp(unsigned Opcode, Value *LHS, Value *RHS, - unsigned OpcToExpand, const Query &Q, +static Value *ExpandBinOp(Instruction::BinaryOps Opcode, Value *LHS, Value *RHS, + Instruction::BinaryOps OpcodeToExpand, const Query &Q, unsigned MaxRecurse) { - Instruction::BinaryOps OpcodeToExpand = (Instruction::BinaryOps)OpcToExpand; // Recursion is always used, so bail out at once if we already hit the limit. if (!MaxRecurse--) return nullptr; @@ -199,9 +199,9 @@ static Value *ExpandBinOp(unsigned Opcode, Value *LHS, Value *RHS, /// Generic simplifications for associative binary operations. /// Returns the simpler value, or null if none was found. -static Value *SimplifyAssociativeBinOp(unsigned Opc, Value *LHS, Value *RHS, - const Query &Q, unsigned MaxRecurse) { - Instruction::BinaryOps Opcode = (Instruction::BinaryOps)Opc; +static Value *SimplifyAssociativeBinOp(Instruction::BinaryOps Opcode, + Value *LHS, Value *RHS, const Query &Q, + unsigned MaxRecurse) { assert(Instruction::isAssociative(Opcode) && "Not an associative operation!"); // Recursion is always used, so bail out at once if we already hit the limit. @@ -298,8 +298,9 @@ static Value *SimplifyAssociativeBinOp(unsigned Opc, Value *LHS, Value *RHS, /// try to simplify the binop by seeing whether evaluating it on both branches /// of the select results in the same value. Returns the common value if so, /// otherwise returns null. -static Value *ThreadBinOpOverSelect(unsigned Opcode, Value *LHS, Value *RHS, - const Query &Q, unsigned MaxRecurse) { +static Value *ThreadBinOpOverSelect(Instruction::BinaryOps Opcode, Value *LHS, + Value *RHS, const Query &Q, + unsigned MaxRecurse) { // Recursion is always used, so bail out at once if we already hit the limit. if (!MaxRecurse--) return nullptr; @@ -451,8 +452,9 @@ static Value *ThreadCmpOverSelect(CmpInst::Predicate Pred, Value *LHS, /// try to simplify the binop by seeing whether evaluating it on the incoming /// phi values yields the same result for every value. If so returns the common /// value, otherwise returns null. -static Value *ThreadBinOpOverPHI(unsigned Opcode, Value *LHS, Value *RHS, - const Query &Q, unsigned MaxRecurse) { +static Value *ThreadBinOpOverPHI(Instruction::BinaryOps Opcode, Value *LHS, + Value *RHS, const Query &Q, + unsigned MaxRecurse) { // Recursion is always used, so bail out at once if we already hit the limit. if (!MaxRecurse--) return nullptr; @@ -527,17 +529,26 @@ static Value *ThreadCmpOverPHI(CmpInst::Predicate Pred, Value *LHS, Value *RHS, return CommonValue; } +static Constant *foldOrCommuteConstant(Instruction::BinaryOps Opcode, + Value *&Op0, Value *&Op1, + const Query &Q) { + if (auto *CLHS = dyn_cast<Constant>(Op0)) { + if (auto *CRHS = dyn_cast<Constant>(Op1)) + return ConstantFoldBinaryOpOperands(Opcode, CLHS, CRHS, Q.DL); + + // Canonicalize the constant to the RHS if this is a commutative operation. + if (Instruction::isCommutative(Opcode)) + std::swap(Op0, Op1); + } + return nullptr; +} + /// Given operands for an Add, see if we can fold the result. /// If not, this returns null. static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, const Query &Q, unsigned MaxRecurse) { - if (Constant *CLHS = dyn_cast<Constant>(Op0)) { - if (Constant *CRHS = dyn_cast<Constant>(Op1)) - return ConstantFoldBinaryOpOperands(Instruction::Add, CLHS, CRHS, Q.DL); - - // Canonicalize the constant to the RHS. - std::swap(Op0, Op1); - } + if (Constant *C = foldOrCommuteConstant(Instruction::Add, Op0, Op1, Q)) + return C; // X + undef -> undef if (match(Op1, m_Undef())) @@ -556,12 +567,20 @@ static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, return Y; // X + ~X -> -1 since ~X = -X-1 + Type *Ty = Op0->getType(); if (match(Op0, m_Not(m_Specific(Op1))) || match(Op1, m_Not(m_Specific(Op0)))) - return Constant::getAllOnesValue(Op0->getType()); + return Constant::getAllOnesValue(Ty); + + // add nsw/nuw (xor Y, signbit), signbit --> Y + // The no-wrapping add guarantees that the top bit will be set by the add. + // Therefore, the xor must be clearing the already set sign bit of Y. + if ((isNSW || isNUW) && match(Op1, m_SignBit()) && + match(Op0, m_Xor(m_Value(Y), m_SignBit()))) + return Y; /// i1 add -> xor. - if (MaxRecurse && Op0->getType()->isIntegerTy(1)) + if (MaxRecurse && Op0->getType()->getScalarType()->isIntegerTy(1)) if (Value *V = SimplifyXorInst(Op0, Op1, Q, MaxRecurse-1)) return V; @@ -665,9 +684,8 @@ static Constant *computePointerDifference(const DataLayout &DL, Value *LHS, /// If not, this returns null. static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, const Query &Q, unsigned MaxRecurse) { - if (Constant *CLHS = dyn_cast<Constant>(Op0)) - if (Constant *CRHS = dyn_cast<Constant>(Op1)) - return ConstantFoldBinaryOpOperands(Instruction::Sub, CLHS, CRHS, Q.DL); + if (Constant *C = foldOrCommuteConstant(Instruction::Sub, Op0, Op1, Q)) + return C; // X - undef -> undef // undef - X -> undef @@ -692,7 +710,7 @@ static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, APInt KnownZero(BitWidth, 0); APInt KnownOne(BitWidth, 0); computeKnownBits(Op1, KnownZero, KnownOne, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); - if (KnownZero == ~APInt::getSignBit(BitWidth)) { + if (KnownZero.isMaxSignedValue()) { // Op1 is either 0 or the minimum signed value. If the sub is NSW, then // Op1 must be 0 because negating the minimum signed value is undefined. if (isNSW) @@ -779,7 +797,7 @@ static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, return ConstantExpr::getIntegerCast(Result, Op0->getType(), true); // i1 sub -> xor. - if (MaxRecurse && Op0->getType()->isIntegerTy(1)) + if (MaxRecurse && Op0->getType()->getScalarType()->isIntegerTy(1)) if (Value *V = SimplifyXorInst(Op0, Op1, Q, MaxRecurse-1)) return V; @@ -807,13 +825,8 @@ Value *llvm::SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, /// returns null. static Value *SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, const Query &Q, unsigned MaxRecurse) { - if (Constant *CLHS = dyn_cast<Constant>(Op0)) { - if (Constant *CRHS = dyn_cast<Constant>(Op1)) - return ConstantFoldBinaryOpOperands(Instruction::FAdd, CLHS, CRHS, Q.DL); - - // Canonicalize the constant to the RHS. - std::swap(Op0, Op1); - } + if (Constant *C = foldOrCommuteConstant(Instruction::FAdd, Op0, Op1, Q)) + return C; // fadd X, -0 ==> X if (match(Op1, m_NegZero())) @@ -846,10 +859,8 @@ static Value *SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, /// returns null. static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, const Query &Q, unsigned MaxRecurse) { - if (Constant *CLHS = dyn_cast<Constant>(Op0)) { - if (Constant *CRHS = dyn_cast<Constant>(Op1)) - return ConstantFoldBinaryOpOperands(Instruction::FSub, CLHS, CRHS, Q.DL); - } + if (Constant *C = foldOrCommuteConstant(Instruction::FSub, Op0, Op1, Q)) + return C; // fsub X, 0 ==> X if (match(Op1, m_Zero())) @@ -878,40 +889,28 @@ static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, } /// Given the operands for an FMul, see if we can fold the result -static Value *SimplifyFMulInst(Value *Op0, Value *Op1, - FastMathFlags FMF, - const Query &Q, - unsigned MaxRecurse) { - if (Constant *CLHS = dyn_cast<Constant>(Op0)) { - if (Constant *CRHS = dyn_cast<Constant>(Op1)) - return ConstantFoldBinaryOpOperands(Instruction::FMul, CLHS, CRHS, Q.DL); - - // Canonicalize the constant to the RHS. - std::swap(Op0, Op1); - } +static Value *SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF, + const Query &Q, unsigned MaxRecurse) { + if (Constant *C = foldOrCommuteConstant(Instruction::FMul, Op0, Op1, Q)) + return C; - // fmul X, 1.0 ==> X - if (match(Op1, m_FPOne())) - return Op0; + // fmul X, 1.0 ==> X + if (match(Op1, m_FPOne())) + return Op0; - // fmul nnan nsz X, 0 ==> 0 - if (FMF.noNaNs() && FMF.noSignedZeros() && match(Op1, m_AnyZero())) - return Op1; + // fmul nnan nsz X, 0 ==> 0 + if (FMF.noNaNs() && FMF.noSignedZeros() && match(Op1, m_AnyZero())) + return Op1; - return nullptr; + return nullptr; } /// Given operands for a Mul, see if we can fold the result. /// If not, this returns null. static Value *SimplifyMulInst(Value *Op0, Value *Op1, const Query &Q, unsigned MaxRecurse) { - if (Constant *CLHS = dyn_cast<Constant>(Op0)) { - if (Constant *CRHS = dyn_cast<Constant>(Op1)) - return ConstantFoldBinaryOpOperands(Instruction::Mul, CLHS, CRHS, Q.DL); - - // Canonicalize the constant to the RHS. - std::swap(Op0, Op1); - } + if (Constant *C = foldOrCommuteConstant(Instruction::Mul, Op0, Op1, Q)) + return C; // X * undef -> 0 if (match(Op1, m_Undef())) @@ -932,7 +931,7 @@ static Value *SimplifyMulInst(Value *Op0, Value *Op1, const Query &Q, return X; // i1 mul -> and. - if (MaxRecurse && Op0->getType()->isIntegerTy(1)) + if (MaxRecurse && Op0->getType()->getScalarType()->isIntegerTy(1)) if (Value *V = SimplifyAndInst(Op0, Op1, Q, MaxRecurse-1)) return V; @@ -998,43 +997,68 @@ Value *llvm::SimplifyMulInst(Value *Op0, Value *Op1, const DataLayout &DL, RecursionLimit); } -/// Given operands for an SDiv or UDiv, see if we can fold the result. -/// If not, this returns null. -static Value *SimplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, - const Query &Q, unsigned MaxRecurse) { - if (Constant *C0 = dyn_cast<Constant>(Op0)) - if (Constant *C1 = dyn_cast<Constant>(Op1)) - return ConstantFoldBinaryOpOperands(Opcode, C0, C1, Q.DL); - - bool isSigned = Opcode == Instruction::SDiv; +/// Check for common or similar folds of integer division or integer remainder. +static Value *simplifyDivRem(Value *Op0, Value *Op1, bool IsDiv) { + Type *Ty = Op0->getType(); // X / undef -> undef + // X % undef -> undef if (match(Op1, m_Undef())) return Op1; - // X / 0 -> undef, we don't need to preserve faults! + // X / 0 -> undef + // X % 0 -> undef + // We don't need to preserve faults! if (match(Op1, m_Zero())) - return UndefValue::get(Op1->getType()); + return UndefValue::get(Ty); + + // If any element of a constant divisor vector is zero, the whole op is undef. + auto *Op1C = dyn_cast<Constant>(Op1); + if (Op1C && Ty->isVectorTy()) { + unsigned NumElts = Ty->getVectorNumElements(); + for (unsigned i = 0; i != NumElts; ++i) { + Constant *Elt = Op1C->getAggregateElement(i); + if (Elt && Elt->isNullValue()) + return UndefValue::get(Ty); + } + } // undef / X -> 0 + // undef % X -> 0 if (match(Op0, m_Undef())) - return Constant::getNullValue(Op0->getType()); + return Constant::getNullValue(Ty); - // 0 / X -> 0, we don't need to preserve faults! + // 0 / X -> 0 + // 0 % X -> 0 if (match(Op0, m_Zero())) return Op0; + // X / X -> 1 + // X % X -> 0 + if (Op0 == Op1) + return IsDiv ? ConstantInt::get(Ty, 1) : Constant::getNullValue(Ty); + // X / 1 -> X - if (match(Op1, m_One())) - return Op0; + // X % 1 -> 0 + // If this is a boolean op (single-bit element type), we can't have + // division-by-zero or remainder-by-zero, so assume the divisor is 1. + if (match(Op1, m_One()) || Ty->getScalarType()->isIntegerTy(1)) + return IsDiv ? Op0 : Constant::getNullValue(Ty); - if (Op0->getType()->isIntegerTy(1)) - // It can't be division by zero, hence it must be division by one. - return Op0; + return nullptr; +} - // X / X -> 1 - if (Op0 == Op1) - return ConstantInt::get(Op0->getType(), 1); +/// Given operands for an SDiv or UDiv, see if we can fold the result. +/// If not, this returns null. +static Value *SimplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, + const Query &Q, unsigned MaxRecurse) { + if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q)) + return C; + + if (Value *V = simplifyDivRem(Op0, Op1, true)) + return V; + + bool isSigned = Opcode == Instruction::SDiv; // (X * Y) / Y -> X if the multiplication does not overflow. Value *X = nullptr, *Y = nullptr; @@ -1129,6 +1153,9 @@ Value *llvm::SimplifyUDivInst(Value *Op0, Value *Op1, const DataLayout &DL, static Value *SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF, const Query &Q, unsigned) { + if (Constant *C = foldOrCommuteConstant(Instruction::FDiv, Op0, Op1, Q)) + return C; + // undef / X -> undef (the undef could be a snan). if (match(Op0, m_Undef())) return Op0; @@ -1178,37 +1205,11 @@ Value *llvm::SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF, /// If not, this returns null. static Value *SimplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, const Query &Q, unsigned MaxRecurse) { - if (Constant *C0 = dyn_cast<Constant>(Op0)) - if (Constant *C1 = dyn_cast<Constant>(Op1)) - return ConstantFoldBinaryOpOperands(Opcode, C0, C1, Q.DL); - - // X % undef -> undef - if (match(Op1, m_Undef())) - return Op1; - - // undef % X -> 0 - if (match(Op0, m_Undef())) - return Constant::getNullValue(Op0->getType()); - - // 0 % X -> 0, we don't need to preserve faults! - if (match(Op0, m_Zero())) - return Op0; + if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q)) + return C; - // X % 0 -> undef, we don't need to preserve faults! - if (match(Op1, m_Zero())) - return UndefValue::get(Op0->getType()); - - // X % 1 -> 0 - if (match(Op1, m_One())) - return Constant::getNullValue(Op0->getType()); - - if (Op0->getType()->isIntegerTy(1)) - // It can't be remainder by zero, hence it must be remainder by one. - return Constant::getNullValue(Op0->getType()); - - // X % X -> 0 - if (Op0 == Op1) - return Constant::getNullValue(Op0->getType()); + if (Value *V = simplifyDivRem(Op0, Op1, false)) + return V; // (X % Y) % Y -> X % Y if ((Opcode == Instruction::SRem && @@ -1279,7 +1280,10 @@ Value *llvm::SimplifyURemInst(Value *Op0, Value *Op1, const DataLayout &DL, } static Value *SimplifyFRemInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const Query &, unsigned) { + const Query &Q, unsigned) { + if (Constant *C = foldOrCommuteConstant(Instruction::FRem, Op0, Op1, Q)) + return C; + // undef % X -> undef (the undef could be a snan). if (match(Op0, m_Undef())) return Op0; @@ -1335,11 +1339,10 @@ static bool isUndefShift(Value *Amount) { /// Given operands for an Shl, LShr or AShr, see if we can fold the result. /// If not, this returns null. -static Value *SimplifyShift(unsigned Opcode, Value *Op0, Value *Op1, - const Query &Q, unsigned MaxRecurse) { - if (Constant *C0 = dyn_cast<Constant>(Op0)) - if (Constant *C1 = dyn_cast<Constant>(Op1)) - return ConstantFoldBinaryOpOperands(Opcode, C0, C1, Q.DL); +static Value *SimplifyShift(Instruction::BinaryOps Opcode, Value *Op0, + Value *Op1, const Query &Q, unsigned MaxRecurse) { + if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q)) + return C; // 0 shift by X -> 0 if (match(Op0, m_Zero())) @@ -1386,8 +1389,8 @@ static Value *SimplifyShift(unsigned Opcode, Value *Op0, Value *Op1, /// \brief Given operands for an Shl, LShr or AShr, see if we can /// fold the result. If not, this returns null. -static Value *SimplifyRightShift(unsigned Opcode, Value *Op0, Value *Op1, - bool isExact, const Query &Q, +static Value *SimplifyRightShift(Instruction::BinaryOps Opcode, Value *Op0, + Value *Op1, bool isExact, const Query &Q, unsigned MaxRecurse) { if (Value *V = SimplifyShift(Opcode, Op0, Op1, Q, MaxRecurse)) return V; @@ -1636,13 +1639,8 @@ static Value *SimplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1) { /// If not, this returns null. static Value *SimplifyAndInst(Value *Op0, Value *Op1, const Query &Q, unsigned MaxRecurse) { - if (Constant *CLHS = dyn_cast<Constant>(Op0)) { - if (Constant *CRHS = dyn_cast<Constant>(Op1)) - return ConstantFoldBinaryOpOperands(Instruction::And, CLHS, CRHS, Q.DL); - - // Canonicalize the constant to the RHS. - std::swap(Op0, Op1); - } + if (Constant *C = foldOrCommuteConstant(Instruction::And, Op0, Op1, Q)) + return C; // X & undef -> 0 if (match(Op1, m_Undef())) @@ -1838,13 +1836,8 @@ static Value *SimplifyOrOfICmps(ICmpInst *Op0, ICmpInst *Op1) { /// If not, this returns null. static Value *SimplifyOrInst(Value *Op0, Value *Op1, const Query &Q, unsigned MaxRecurse) { - if (Constant *CLHS = dyn_cast<Constant>(Op0)) { - if (Constant *CRHS = dyn_cast<Constant>(Op1)) - return ConstantFoldBinaryOpOperands(Instruction::Or, CLHS, CRHS, Q.DL); - - // Canonicalize the constant to the RHS. - std::swap(Op0, Op1); - } + if (Constant *C = foldOrCommuteConstant(Instruction::Or, Op0, Op1, Q)) + return C; // X | undef -> -1 if (match(Op1, m_Undef())) @@ -1971,13 +1964,8 @@ Value *llvm::SimplifyOrInst(Value *Op0, Value *Op1, const DataLayout &DL, /// If not, this returns null. static Value *SimplifyXorInst(Value *Op0, Value *Op1, const Query &Q, unsigned MaxRecurse) { - if (Constant *CLHS = dyn_cast<Constant>(Op0)) { - if (Constant *CRHS = dyn_cast<Constant>(Op1)) - return ConstantFoldBinaryOpOperands(Instruction::Xor, CLHS, CRHS, Q.DL); - - // Canonicalize the constant to the RHS. - std::swap(Op0, Op1); - } + if (Constant *C = foldOrCommuteConstant(Instruction::Xor, Op0, Op1, Q)) + return C; // A ^ undef -> undef if (match(Op1, m_Undef())) @@ -2377,6 +2365,163 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS, return nullptr; } +/// Many binary operators with a constant operand have an easy-to-compute +/// range of outputs. This can be used to fold a comparison to always true or +/// always false. +static void setLimitsForBinOp(BinaryOperator &BO, APInt &Lower, APInt &Upper) { + unsigned Width = Lower.getBitWidth(); + const APInt *C; + switch (BO.getOpcode()) { + case Instruction::Add: + if (match(BO.getOperand(1), m_APInt(C)) && *C != 0) { + // FIXME: If we have both nuw and nsw, we should reduce the range further. + if (BO.hasNoUnsignedWrap()) { + // 'add nuw x, C' produces [C, UINT_MAX]. + Lower = *C; + } else if (BO.hasNoSignedWrap()) { + if (C->isNegative()) { + // 'add nsw x, -C' produces [SINT_MIN, SINT_MAX - C]. + Lower = APInt::getSignedMinValue(Width); + Upper = APInt::getSignedMaxValue(Width) + *C + 1; + } else { + // 'add nsw x, +C' produces [SINT_MIN + C, SINT_MAX]. + Lower = APInt::getSignedMinValue(Width) + *C; + Upper = APInt::getSignedMaxValue(Width) + 1; + } + } + } + break; + + case Instruction::And: + if (match(BO.getOperand(1), m_APInt(C))) + // 'and x, C' produces [0, C]. + Upper = *C + 1; + break; + + case Instruction::Or: + if (match(BO.getOperand(1), m_APInt(C))) + // 'or x, C' produces [C, UINT_MAX]. + Lower = *C; + break; + + case Instruction::AShr: + if (match(BO.getOperand(1), m_APInt(C)) && C->ult(Width)) { + // 'ashr x, C' produces [INT_MIN >> C, INT_MAX >> C]. + Lower = APInt::getSignedMinValue(Width).ashr(*C); + Upper = APInt::getSignedMaxValue(Width).ashr(*C) + 1; + } else if (match(BO.getOperand(0), m_APInt(C))) { + unsigned ShiftAmount = Width - 1; + if (*C != 0 && BO.isExact()) + ShiftAmount = C->countTrailingZeros(); + if (C->isNegative()) { + // 'ashr C, x' produces [C, C >> (Width-1)] + Lower = *C; + Upper = C->ashr(ShiftAmount) + 1; + } else { + // 'ashr C, x' produces [C >> (Width-1), C] + Lower = C->ashr(ShiftAmount); + Upper = *C + 1; + } + } + break; + + case Instruction::LShr: + if (match(BO.getOperand(1), m_APInt(C)) && C->ult(Width)) { + // 'lshr x, C' produces [0, UINT_MAX >> C]. + Upper = APInt::getAllOnesValue(Width).lshr(*C) + 1; + } else if (match(BO.getOperand(0), m_APInt(C))) { + // 'lshr C, x' produces [C >> (Width-1), C]. + unsigned ShiftAmount = Width - 1; + if (*C != 0 && BO.isExact()) + ShiftAmount = C->countTrailingZeros(); + Lower = C->lshr(ShiftAmount); + Upper = *C + 1; + } + break; + + case Instruction::Shl: + if (match(BO.getOperand(0), m_APInt(C))) { + if (BO.hasNoUnsignedWrap()) { + // 'shl nuw C, x' produces [C, C << CLZ(C)] + Lower = *C; + Upper = Lower.shl(Lower.countLeadingZeros()) + 1; + } else if (BO.hasNoSignedWrap()) { // TODO: What if both nuw+nsw? + if (C->isNegative()) { + // 'shl nsw C, x' produces [C << CLO(C)-1, C] + unsigned ShiftAmount = C->countLeadingOnes() - 1; + Lower = C->shl(ShiftAmount); + Upper = *C + 1; + } else { + // 'shl nsw C, x' produces [C, C << CLZ(C)-1] + unsigned ShiftAmount = C->countLeadingZeros() - 1; + Lower = *C; + Upper = C->shl(ShiftAmount) + 1; + } + } + } + break; + + case Instruction::SDiv: + if (match(BO.getOperand(1), m_APInt(C))) { + APInt IntMin = APInt::getSignedMinValue(Width); + APInt IntMax = APInt::getSignedMaxValue(Width); + if (C->isAllOnesValue()) { + // 'sdiv x, -1' produces [INT_MIN + 1, INT_MAX] + // where C != -1 and C != 0 and C != 1 + Lower = IntMin + 1; + Upper = IntMax + 1; + } else if (C->countLeadingZeros() < Width - 1) { + // 'sdiv x, C' produces [INT_MIN / C, INT_MAX / C] + // where C != -1 and C != 0 and C != 1 + Lower = IntMin.sdiv(*C); + Upper = IntMax.sdiv(*C); + if (Lower.sgt(Upper)) + std::swap(Lower, Upper); + Upper = Upper + 1; + assert(Upper != Lower && "Upper part of range has wrapped!"); + } + } else if (match(BO.getOperand(0), m_APInt(C))) { + if (C->isMinSignedValue()) { + // 'sdiv INT_MIN, x' produces [INT_MIN, INT_MIN / -2]. + Lower = *C; + Upper = Lower.lshr(1) + 1; + } else { + // 'sdiv C, x' produces [-|C|, |C|]. + Upper = C->abs() + 1; + Lower = (-Upper) + 1; + } + } + break; + + case Instruction::UDiv: + if (match(BO.getOperand(1), m_APInt(C)) && *C != 0) { + // 'udiv x, C' produces [0, UINT_MAX / C]. + Upper = APInt::getMaxValue(Width).udiv(*C) + 1; + } else if (match(BO.getOperand(0), m_APInt(C))) { + // 'udiv C, x' produces [0, C]. + Upper = *C + 1; + } + break; + + case Instruction::SRem: + if (match(BO.getOperand(1), m_APInt(C))) { + // 'srem x, C' produces (-|C|, |C|). + Upper = C->abs(); + Lower = (-Upper) + 1; + } + break; + + case Instruction::URem: + if (match(BO.getOperand(1), m_APInt(C))) + // 'urem x, C' produces [0, C). + Upper = *C; + break; + + default: + break; + } +} + static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS, Value *RHS) { const APInt *C; @@ -2390,114 +2535,12 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS, if (RHS_CR.isFullSet()) return ConstantInt::getTrue(GetCompareTy(RHS)); - // Many binary operators with constant RHS have easy to compute constant - // range. Use them to check whether the comparison is a tautology. + // Find the range of possible values for binary operators. unsigned Width = C->getBitWidth(); APInt Lower = APInt(Width, 0); APInt Upper = APInt(Width, 0); - const APInt *C2; - if (match(LHS, m_URem(m_Value(), m_APInt(C2)))) { - // 'urem x, C2' produces [0, C2). - Upper = *C2; - } else if (match(LHS, m_SRem(m_Value(), m_APInt(C2)))) { - // 'srem x, C2' produces (-|C2|, |C2|). - Upper = C2->abs(); - Lower = (-Upper) + 1; - } else if (match(LHS, m_UDiv(m_APInt(C2), m_Value()))) { - // 'udiv C2, x' produces [0, C2]. - Upper = *C2 + 1; - } else if (match(LHS, m_UDiv(m_Value(), m_APInt(C2)))) { - // 'udiv x, C2' produces [0, UINT_MAX / C2]. - APInt NegOne = APInt::getAllOnesValue(Width); - if (*C2 != 0) - Upper = NegOne.udiv(*C2) + 1; - } else if (match(LHS, m_SDiv(m_APInt(C2), m_Value()))) { - if (C2->isMinSignedValue()) { - // 'sdiv INT_MIN, x' produces [INT_MIN, INT_MIN / -2]. - Lower = *C2; - Upper = Lower.lshr(1) + 1; - } else { - // 'sdiv C2, x' produces [-|C2|, |C2|]. - Upper = C2->abs() + 1; - Lower = (-Upper) + 1; - } - } else if (match(LHS, m_SDiv(m_Value(), m_APInt(C2)))) { - APInt IntMin = APInt::getSignedMinValue(Width); - APInt IntMax = APInt::getSignedMaxValue(Width); - if (C2->isAllOnesValue()) { - // 'sdiv x, -1' produces [INT_MIN + 1, INT_MAX] - // where C2 != -1 and C2 != 0 and C2 != 1 - Lower = IntMin + 1; - Upper = IntMax + 1; - } else if (C2->countLeadingZeros() < Width - 1) { - // 'sdiv x, C2' produces [INT_MIN / C2, INT_MAX / C2] - // where C2 != -1 and C2 != 0 and C2 != 1 - Lower = IntMin.sdiv(*C2); - Upper = IntMax.sdiv(*C2); - if (Lower.sgt(Upper)) - std::swap(Lower, Upper); - Upper = Upper + 1; - assert(Upper != Lower && "Upper part of range has wrapped!"); - } - } else if (match(LHS, m_NUWShl(m_APInt(C2), m_Value()))) { - // 'shl nuw C2, x' produces [C2, C2 << CLZ(C2)] - Lower = *C2; - Upper = Lower.shl(Lower.countLeadingZeros()) + 1; - } else if (match(LHS, m_NSWShl(m_APInt(C2), m_Value()))) { - if (C2->isNegative()) { - // 'shl nsw C2, x' produces [C2 << CLO(C2)-1, C2] - unsigned ShiftAmount = C2->countLeadingOnes() - 1; - Lower = C2->shl(ShiftAmount); - Upper = *C2 + 1; - } else { - // 'shl nsw C2, x' produces [C2, C2 << CLZ(C2)-1] - unsigned ShiftAmount = C2->countLeadingZeros() - 1; - Lower = *C2; - Upper = C2->shl(ShiftAmount) + 1; - } - } else if (match(LHS, m_LShr(m_Value(), m_APInt(C2)))) { - // 'lshr x, C2' produces [0, UINT_MAX >> C2]. - APInt NegOne = APInt::getAllOnesValue(Width); - if (C2->ult(Width)) - Upper = NegOne.lshr(*C2) + 1; - } else if (match(LHS, m_LShr(m_APInt(C2), m_Value()))) { - // 'lshr C2, x' produces [C2 >> (Width-1), C2]. - unsigned ShiftAmount = Width - 1; - if (*C2 != 0 && cast<BinaryOperator>(LHS)->isExact()) - ShiftAmount = C2->countTrailingZeros(); - Lower = C2->lshr(ShiftAmount); - Upper = *C2 + 1; - } else if (match(LHS, m_AShr(m_Value(), m_APInt(C2)))) { - // 'ashr x, C2' produces [INT_MIN >> C2, INT_MAX >> C2]. - APInt IntMin = APInt::getSignedMinValue(Width); - APInt IntMax = APInt::getSignedMaxValue(Width); - if (C2->ult(Width)) { - Lower = IntMin.ashr(*C2); - Upper = IntMax.ashr(*C2) + 1; - } - } else if (match(LHS, m_AShr(m_APInt(C2), m_Value()))) { - unsigned ShiftAmount = Width - 1; - if (*C2 != 0 && cast<BinaryOperator>(LHS)->isExact()) - ShiftAmount = C2->countTrailingZeros(); - if (C2->isNegative()) { - // 'ashr C2, x' produces [C2, C2 >> (Width-1)] - Lower = *C2; - Upper = C2->ashr(ShiftAmount) + 1; - } else { - // 'ashr C2, x' produces [C2 >> (Width-1), C2] - Lower = C2->ashr(ShiftAmount); - Upper = *C2 + 1; - } - } else if (match(LHS, m_Or(m_Value(), m_APInt(C2)))) { - // 'or x, C2' produces [C2, UINT_MAX]. - Lower = *C2; - } else if (match(LHS, m_And(m_Value(), m_APInt(C2)))) { - // 'and x, C2' produces [0, C2]. - Upper = *C2 + 1; - } else if (match(LHS, m_NUWAdd(m_Value(), m_APInt(C2)))) { - // 'add nuw x, C2' produces [C2, UINT_MAX]. - Lower = *C2; - } + if (auto *BO = dyn_cast<BinaryOperator>(LHS)) + setLimitsForBinOp(*BO, Lower, Upper); ConstantRange LHS_CR = Lower != Upper ? ConstantRange(Lower, Upper) : ConstantRange(Width, true); @@ -3064,8 +3107,8 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, // If both operands have range metadata, use the metadata // to simplify the comparison. if (isa<Instruction>(RHS) && isa<Instruction>(LHS)) { - auto RHS_Instr = dyn_cast<Instruction>(RHS); - auto LHS_Instr = dyn_cast<Instruction>(LHS); + auto RHS_Instr = cast<Instruction>(RHS); + auto LHS_Instr = cast<Instruction>(LHS); if (RHS_Instr->getMetadata(LLVMContext::MD_range) && LHS_Instr->getMetadata(LLVMContext::MD_range)) { @@ -4039,6 +4082,62 @@ Value *llvm::SimplifyCastInst(unsigned CastOpc, Value *Op, Type *Ty, RecursionLimit); } +static Value *SimplifyShuffleVectorInst(Value *Op0, Value *Op1, Constant *Mask, + Type *RetTy, const Query &Q, + unsigned MaxRecurse) { + Type *InVecTy = Op0->getType(); + unsigned MaskNumElts = Mask->getType()->getVectorNumElements(); + unsigned InVecNumElts = InVecTy->getVectorNumElements(); + + auto *Op0Const = dyn_cast<Constant>(Op0); + auto *Op1Const = dyn_cast<Constant>(Op1); + + // If all operands are constant, constant fold the shuffle. + if (Op0Const && Op1Const) + return ConstantFoldShuffleVectorInstruction(Op0Const, Op1Const, Mask); + + // If only one of the operands is constant, constant fold the shuffle if the + // mask does not select elements from the variable operand. + bool MaskSelects0 = false, MaskSelects1 = false; + for (unsigned i = 0; i != MaskNumElts; ++i) { + int Idx = ShuffleVectorInst::getMaskValue(Mask, i); + if (Idx == -1) + continue; + if ((unsigned)Idx < InVecNumElts) + MaskSelects0 = true; + else + MaskSelects1 = true; + } + if (!MaskSelects0 && Op1Const) + return ConstantFoldShuffleVectorInstruction(UndefValue::get(InVecTy), + Op1Const, Mask); + if (!MaskSelects1 && Op0Const) + return ConstantFoldShuffleVectorInstruction(Op0Const, + UndefValue::get(InVecTy), Mask); + + // A shuffle of a splat is always the splat itself. Legal if the shuffle's + // value type is same as the input vectors' type. + if (auto *OpShuf = dyn_cast<ShuffleVectorInst>(Op0)) + if (!MaskSelects1 && RetTy == InVecTy && + OpShuf->getMask()->getSplatValue()) + return Op0; + if (auto *OpShuf = dyn_cast<ShuffleVectorInst>(Op1)) + if (!MaskSelects0 && RetTy == InVecTy && + OpShuf->getMask()->getSplatValue()) + return Op1; + + return nullptr; +} + +/// Given operands for a ShuffleVectorInst, fold the result or return null. +Value *llvm::SimplifyShuffleVectorInst( + Value *Op0, Value *Op1, Constant *Mask, Type *RetTy, + const DataLayout &DL, const TargetLibraryInfo *TLI, const DominatorTree *DT, + AssumptionCache *AC, const Instruction *CxtI) { + return ::SimplifyShuffleVectorInst( + Op0, Op1, Mask, RetTy, Query(DL, TLI, DT, AC, CxtI), RecursionLimit); +} + //=== Helper functions for higher up the class hierarchy. /// Given operands for a BinaryOperator, see if we can fold the result. @@ -4047,61 +4146,43 @@ static Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, const Query &Q, unsigned MaxRecurse) { switch (Opcode) { case Instruction::Add: - return SimplifyAddInst(LHS, RHS, /*isNSW*/false, /*isNUW*/false, - Q, MaxRecurse); + return SimplifyAddInst(LHS, RHS, false, false, Q, MaxRecurse); case Instruction::FAdd: return SimplifyFAddInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); - case Instruction::Sub: - return SimplifySubInst(LHS, RHS, /*isNSW*/false, /*isNUW*/false, - Q, MaxRecurse); + return SimplifySubInst(LHS, RHS, false, false, Q, MaxRecurse); case Instruction::FSub: return SimplifyFSubInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); - - case Instruction::Mul: return SimplifyMulInst (LHS, RHS, Q, MaxRecurse); + case Instruction::Mul: + return SimplifyMulInst(LHS, RHS, Q, MaxRecurse); case Instruction::FMul: - return SimplifyFMulInst (LHS, RHS, FastMathFlags(), Q, MaxRecurse); - case Instruction::SDiv: return SimplifySDivInst(LHS, RHS, Q, MaxRecurse); - case Instruction::UDiv: return SimplifyUDivInst(LHS, RHS, Q, MaxRecurse); + return SimplifyFMulInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); + case Instruction::SDiv: + return SimplifySDivInst(LHS, RHS, Q, MaxRecurse); + case Instruction::UDiv: + return SimplifyUDivInst(LHS, RHS, Q, MaxRecurse); case Instruction::FDiv: - return SimplifyFDivInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); - case Instruction::SRem: return SimplifySRemInst(LHS, RHS, Q, MaxRecurse); - case Instruction::URem: return SimplifyURemInst(LHS, RHS, Q, MaxRecurse); + return SimplifyFDivInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); + case Instruction::SRem: + return SimplifySRemInst(LHS, RHS, Q, MaxRecurse); + case Instruction::URem: + return SimplifyURemInst(LHS, RHS, Q, MaxRecurse); case Instruction::FRem: - return SimplifyFRemInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); + return SimplifyFRemInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); case Instruction::Shl: - return SimplifyShlInst(LHS, RHS, /*isNSW*/false, /*isNUW*/false, - Q, MaxRecurse); + return SimplifyShlInst(LHS, RHS, false, false, Q, MaxRecurse); case Instruction::LShr: - return SimplifyLShrInst(LHS, RHS, /*isExact*/false, Q, MaxRecurse); + return SimplifyLShrInst(LHS, RHS, false, Q, MaxRecurse); case Instruction::AShr: - return SimplifyAShrInst(LHS, RHS, /*isExact*/false, Q, MaxRecurse); - case Instruction::And: return SimplifyAndInst(LHS, RHS, Q, MaxRecurse); - case Instruction::Or: return SimplifyOrInst (LHS, RHS, Q, MaxRecurse); - case Instruction::Xor: return SimplifyXorInst(LHS, RHS, Q, MaxRecurse); + return SimplifyAShrInst(LHS, RHS, false, Q, MaxRecurse); + case Instruction::And: + return SimplifyAndInst(LHS, RHS, Q, MaxRecurse); + case Instruction::Or: + return SimplifyOrInst(LHS, RHS, Q, MaxRecurse); + case Instruction::Xor: + return SimplifyXorInst(LHS, RHS, Q, MaxRecurse); default: - if (Constant *CLHS = dyn_cast<Constant>(LHS)) - if (Constant *CRHS = dyn_cast<Constant>(RHS)) - return ConstantFoldBinaryOpOperands(Opcode, CLHS, CRHS, Q.DL); - - // If the operation is associative, try some generic simplifications. - if (Instruction::isAssociative(Opcode)) - if (Value *V = SimplifyAssociativeBinOp(Opcode, LHS, RHS, Q, MaxRecurse)) - return V; - - // If the operation is with the result of a select instruction check whether - // operating on either branch of the select always yields the same value. - if (isa<SelectInst>(LHS) || isa<SelectInst>(RHS)) - if (Value *V = ThreadBinOpOverSelect(Opcode, LHS, RHS, Q, MaxRecurse)) - return V; - - // If the operation is with the result of a phi instruction, check whether - // operating on all incoming values of the phi always yields the same value. - if (isa<PHINode>(LHS) || isa<PHINode>(RHS)) - if (Value *V = ThreadBinOpOverPHI(Opcode, LHS, RHS, Q, MaxRecurse)) - return V; - - return nullptr; + llvm_unreachable("Unexpected opcode"); } } @@ -4267,6 +4348,7 @@ static Value *SimplifyIntrinsic(Function *F, IterTy ArgBegin, IterTy ArgEnd, case Intrinsic::fabs: { if (SignBitMustBeZero(*ArgBegin, Q.TLI)) return *ArgBegin; + return nullptr; } default: return nullptr; @@ -4396,7 +4478,8 @@ Value *llvm::SimplifyCall(Value *V, ArrayRef<Value *> Args, /// If not, this returns null. Value *llvm::SimplifyInstruction(Instruction *I, const DataLayout &DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC) { + const DominatorTree *DT, AssumptionCache *AC, + OptimizationRemarkEmitter *ORE) { Value *Result; switch (I->getOpcode()) { @@ -4522,6 +4605,13 @@ Value *llvm::SimplifyInstruction(Instruction *I, const DataLayout &DL, EEI->getVectorOperand(), EEI->getIndexOperand(), DL, TLI, DT, AC, I); break; } + case Instruction::ShuffleVector: { + auto *SVI = cast<ShuffleVectorInst>(I); + Result = SimplifyShuffleVectorInst(SVI->getOperand(0), SVI->getOperand(1), + SVI->getMask(), SVI->getType(), DL, TLI, + DT, AC, I); + break; + } case Instruction::PHI: Result = SimplifyPHINode(cast<PHINode>(I), Query(DL, TLI, DT, AC, I)); break; @@ -4537,6 +4627,10 @@ Value *llvm::SimplifyInstruction(Instruction *I, const DataLayout &DL, Result = SimplifyCastInst(I->getOpcode(), I->getOperand(0), I->getType(), DL, TLI, DT, AC, I); break; + case Instruction::Alloca: + // No simplifications for Alloca and it can't be constant folded. + Result = nullptr; + break; } // In general, it is possible for computeKnownBits to determine all bits in a @@ -4545,7 +4639,7 @@ Value *llvm::SimplifyInstruction(Instruction *I, const DataLayout &DL, unsigned BitWidth = I->getType()->getScalarSizeInBits(); APInt KnownZero(BitWidth, 0); APInt KnownOne(BitWidth, 0); - computeKnownBits(I, KnownZero, KnownOne, DL, /*Depth*/0, AC, I, DT); + computeKnownBits(I, KnownZero, KnownOne, DL, /*Depth*/0, AC, I, DT, ORE); if ((KnownZero | KnownOne).isAllOnesValue()) Result = ConstantInt::get(I->getType(), KnownOne); } |