Diffstat (limited to 'lib/Analysis/InstructionSimplify.cpp')
 lib/Analysis/InstructionSimplify.cpp | 506 ++++++++++++++++++++++-----------
 1 file changed, 336 insertions(+), 170 deletions(-)
diff --git a/lib/Analysis/InstructionSimplify.cpp b/lib/Analysis/InstructionSimplify.cpp
index 93fb1143e505..519d6d67be51 100644
--- a/lib/Analysis/InstructionSimplify.cpp
+++ b/lib/Analysis/InstructionSimplify.cpp
@@ -62,6 +62,8 @@
 static Value *SimplifyOrInst(Value *, Value *, const SimplifyQuery &, unsigned);
 static Value *SimplifyXorInst(Value *, Value *, const SimplifyQuery &, unsigned);
 static Value *SimplifyCastInst(unsigned, Value *, Type *, const SimplifyQuery &, unsigned);
+static Value *SimplifyGEPInst(Type *, ArrayRef<Value *>, const SimplifyQuery &,
+                              unsigned);
 
 /// For a boolean type or a vector of boolean type, return false or a vector
 /// with every element false.
@@ -90,7 +92,7 @@ static bool isSameCompare(Value *V, CmpInst::Predicate Pred, Value *LHS,
 }
 
 /// Does the given value dominate the specified phi node?
-static bool ValueDominatesPHI(Value *V, PHINode *P, const DominatorTree *DT) {
+static bool valueDominatesPHI(Value *V, PHINode *P, const DominatorTree *DT) {
   Instruction *I = dyn_cast<Instruction>(V);
   if (!I)
     // Arguments and constants dominate all instructions.
@@ -99,7 +101,7 @@ static bool ValueDominatesPHI(Value *V, PHINode *P, const DominatorTree *DT) {
   // If we are processing instructions (and/or basic blocks) that have not been
   // fully added to a function, the parent nodes may still be null. Simply
   // return the conservative answer in these cases.
-  if (!I->getParent() || !P->getParent() || !I->getParent()->getParent())
+  if (!I->getParent() || !P->getParent() || !I->getFunction())
     return false;
 
   // If we have a DominatorTree then do a precise test.
@@ -108,7 +110,7 @@ static bool ValueDominatesPHI(Value *V, PHINode *P, const DominatorTree *DT) {
 
   // Otherwise, if the instruction is in the entry block and is not an invoke,
   // then it obviously dominates all phi nodes.
-  if (I->getParent() == &I->getParent()->getParent()->getEntryBlock() &&
+  if (I->getParent() == &I->getFunction()->getEntryBlock() &&
       !isa<InvokeInst>(I))
     return true;
 
@@ -443,13 +445,13 @@ static Value *ThreadBinOpOverPHI(Instruction::BinaryOps Opcode, Value *LHS,
   if (isa<PHINode>(LHS)) {
     PI = cast<PHINode>(LHS);
     // Bail out if RHS and the phi may be mutually interdependent due to a loop.
-    if (!ValueDominatesPHI(RHS, PI, Q.DT))
+    if (!valueDominatesPHI(RHS, PI, Q.DT))
       return nullptr;
   } else {
     assert(isa<PHINode>(RHS) && "No PHI instruction operand!");
     PI = cast<PHINode>(RHS);
     // Bail out if LHS and the phi may be mutually interdependent due to a loop.
-    if (!ValueDominatesPHI(LHS, PI, Q.DT))
+    if (!valueDominatesPHI(LHS, PI, Q.DT))
       return nullptr;
   }
 
@@ -490,7 +492,7 @@ static Value *ThreadCmpOverPHI(CmpInst::Predicate Pred, Value *LHS, Value *RHS,
   PHINode *PI = cast<PHINode>(LHS);
 
   // Bail out if RHS and the phi may be mutually interdependent due to a loop.
-  if (!ValueDominatesPHI(RHS, PI, Q.DT))
+  if (!valueDominatesPHI(RHS, PI, Q.DT))
     return nullptr;
 
   // Evaluate the BinOp on the incoming phi values.
@@ -525,7 +527,7 @@ static Constant *foldOrCommuteConstant(Instruction::BinaryOps Opcode,
 
 /// Given operands for an Add, see if we can fold the result.
 /// If not, this returns null.
-static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
+static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW,
                               const SimplifyQuery &Q, unsigned MaxRecurse) {
   if (Constant *C = foldOrCommuteConstant(Instruction::Add, Op0, Op1, Q))
     return C;
@@ -538,6 +540,10 @@ static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
   if (match(Op1, m_Zero()))
     return Op0;
 
+  // If the two operands are negated, return 0.
+  if (isKnownNegation(Op0, Op1))
+    return Constant::getNullValue(Op0->getType());
+
   // X + (Y - X) -> Y
   // (Y - X) + X -> Y
   // Eg: X + -X -> 0
@@ -555,10 +561,14 @@ static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
   // add nsw/nuw (xor Y, signmask), signmask --> Y
   // The no-wrapping add guarantees that the top bit will be set by the add.
   // Therefore, the xor must be clearing the already set sign bit of Y.
-  if ((isNSW || isNUW) && match(Op1, m_SignMask()) &&
+  if ((IsNSW || IsNUW) && match(Op1, m_SignMask()) &&
       match(Op0, m_Xor(m_Value(Y), m_SignMask())))
     return Y;
 
+  // add nuw %x, -1 -> -1, because %x can only be 0.
+  if (IsNUW && match(Op1, m_AllOnes()))
+    return Op1; // Which is -1.
+
   /// i1 add -> xor.
   if (MaxRecurse && Op0->getType()->isIntOrIntVectorTy(1))
     if (Value *V = SimplifyXorInst(Op0, Op1, Q, MaxRecurse-1))
@@ -581,12 +591,12 @@
   return nullptr;
 }
 
-Value *llvm::SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
+Value *llvm::SimplifyAddInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW,
                              const SimplifyQuery &Query) {
-  return ::SimplifyAddInst(Op0, Op1, isNSW, isNUW, Query, RecursionLimit);
+  return ::SimplifyAddInst(Op0, Op1, IsNSW, IsNUW, Query, RecursionLimit);
 }
 
-/// \brief Compute the base pointer and cumulative constant offsets for V.
+/// Compute the base pointer and cumulative constant offsets for V.
 ///
 /// This strips all constant offsets off of V, leaving it the base pointer, and
 /// accumulates the total constant offset applied in the returned constant. It
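The two new add folds above would fire on IR like the following (a hand-written
sketch in the style of the InstSimplify regression tests; function names are
invented, not taken from this commit):

  define i8 @add_of_negation(i8 %x) {
    ; %neg is a known negation of %x, so the sum folds to 0.
    %neg = sub i8 0, %x
    %r = add i8 %neg, %x
    ret i8 %r
  }

  define i8 @add_nuw_allones(i8 %x) {
    ; With nuw, %x + 255 cannot wrap, so %x must be 0 and the result is -1.
    %r = add nuw i8 %x, -1
    ret i8 %r
  }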
@@ -637,7 +647,7 @@ static Constant *stripAndComputeConstantOffsets(const DataLayout &DL, Value *&V,
   return OffsetIntPtr;
 }
 
-/// \brief Compute the constant difference between two pointer values.
+/// Compute the constant difference between two pointer values.
 /// If the difference is not a constant, returns zero.
 static Constant *computePointerDifference(const DataLayout &DL, Value *LHS,
                                           Value *RHS) {
@@ -680,14 +690,14 @@ static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
   if (match(Op0, m_Zero())) {
     // 0 - X -> 0 if the sub is NUW.
     if (isNUW)
-      return Op0;
+      return Constant::getNullValue(Op0->getType());
 
     KnownBits Known = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
     if (Known.Zero.isMaxSignedValue()) {
       // Op1 is either 0 or the minimum signed value. If the sub is NSW, then
       // Op1 must be 0 because negating the minimum signed value is undefined.
       if (isNSW)
-        return Op0;
+        return Constant::getNullValue(Op0->getType());
 
       // 0 - X -> X if X is 0 or the minimum signed value.
       return Op1;
@@ -799,12 +809,9 @@ static Value *SimplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
     return C;
 
   // X * undef -> 0
-  if (match(Op1, m_Undef()))
-    return Constant::getNullValue(Op0->getType());
-
   // X * 0 -> 0
-  if (match(Op1, m_Zero()))
-    return Op1;
+  if (match(Op1, m_CombineOr(m_Undef(), m_Zero())))
+    return Constant::getNullValue(Op0->getType());
 
   // X * 1 -> X
   if (match(Op1, m_One()))
     return Op0;
@@ -826,7 +833,7 @@ static Value *SimplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
                              MaxRecurse))
     return V;
 
-  // Mul distributes over Add.  Try some generic simplifications based on this.
+  // Mul distributes over Add. Try some generic simplifications based on this.
   if (Value *V = ExpandBinOp(Instruction::Mul, Op0, Op1, Instruction::Add,
                              Q, MaxRecurse))
     return V;
@@ -868,13 +875,14 @@ static Value *simplifyDivRem(Value *Op0, Value *Op1, bool IsDiv) {
   if (match(Op1, m_Zero()))
     return UndefValue::get(Ty);
 
-  // If any element of a constant divisor vector is zero, the whole op is undef.
+  // If any element of a constant divisor vector is zero or undef, the whole op
+  // is undef.
   auto *Op1C = dyn_cast<Constant>(Op1);
   if (Op1C && Ty->isVectorTy()) {
     unsigned NumElts = Ty->getVectorNumElements();
     for (unsigned i = 0; i != NumElts; ++i) {
       Constant *Elt = Op1C->getAggregateElement(i);
-      if (Elt && Elt->isNullValue())
+      if (Elt && (Elt->isNullValue() || isa<UndefValue>(Elt)))
         return UndefValue::get(Ty);
     }
   }
@@ -887,7 +895,7 @@ static Value *simplifyDivRem(Value *Op0, Value *Op1, bool IsDiv) {
   // 0 / X -> 0
   // 0 % X -> 0
   if (match(Op0, m_Zero()))
-    return Op0;
+    return Constant::getNullValue(Op0->getType());
 
   // X / X -> 1
   // X % X -> 0
@@ -898,7 +906,10 @@ static Value *simplifyDivRem(Value *Op0, Value *Op1, bool IsDiv) {
   // X / 1 -> X
   // X % 1 -> 0
   // If this is a boolean op (single-bit element type), we can't have
   // division-by-zero or remainder-by-zero, so assume the divisor is 1.
-  if (match(Op1, m_One()) || Ty->isIntOrIntVectorTy(1))
+  // Similarly, if we're zero-extending a boolean divisor, then assume it's a 1.
+  Value *X;
+  if (match(Op1, m_One()) || Ty->isIntOrIntVectorTy(1) ||
+      (match(Op1, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)))
     return IsDiv ? Op0 : Constant::getNullValue(Ty);
 
   return nullptr;
 }
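The new zext-of-bool case folds a division whose divisor is known to be 0 or 1:
since division by zero is undefined, the divisor may be assumed to be 1
(hand-written sketch, invented names):

  define i32 @udiv_by_zext_bool(i32 %x, i1 %b) {
    ; %d is 0 or 1; division by 0 is UB, so assume 1 and fold to %x.
    %d = zext i1 %b to i32
    %r = udiv i32 %x, %d
    ret i32 %r
  }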
@@ -978,18 +989,17 @@ static Value *simplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
   bool IsSigned = Opcode == Instruction::SDiv;
 
   // (X * Y) / Y -> X if the multiplication does not overflow.
-  Value *X = nullptr, *Y = nullptr;
-  if (match(Op0, m_Mul(m_Value(X), m_Value(Y))) && (X == Op1 || Y == Op1)) {
-    if (Y != Op1) std::swap(X, Y); // Ensure expression is (X * Y) / Y, Y = Op1
-    OverflowingBinaryOperator *Mul = cast<OverflowingBinaryOperator>(Op0);
-    // If the Mul knows it does not overflow, then we are good to go.
+  Value *X;
+  if (match(Op0, m_c_Mul(m_Value(X), m_Specific(Op1)))) {
+    auto *Mul = cast<OverflowingBinaryOperator>(Op0);
+    // If the Mul does not overflow, then we are good to go.
     if ((IsSigned && Mul->hasNoSignedWrap()) ||
         (!IsSigned && Mul->hasNoUnsignedWrap()))
       return X;
-    // If X has the form X = A / Y then X * Y cannot overflow.
-    if (BinaryOperator *Div = dyn_cast<BinaryOperator>(X))
-      if (Div->getOpcode() == Opcode && Div->getOperand(1) == Y)
-        return X;
+    // If X has the form X = A / Y, then X * Y cannot overflow.
+    if ((IsSigned && match(X, m_SDiv(m_Value(), m_Specific(Op1)))) ||
+        (!IsSigned && match(X, m_UDiv(m_Value(), m_Specific(Op1)))))
+      return X;
   }
 
   // (X rem Y) / Y -> 0
@@ -1041,6 +1051,13 @@ static Value *simplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
       match(Op0, m_URem(m_Value(), m_Specific(Op1)))))
     return Op0;
 
+  // (X << Y) % X -> 0
+  if ((Opcode == Instruction::SRem &&
+       match(Op0, m_NSWShl(m_Specific(Op1), m_Value()))) ||
+      (Opcode == Instruction::URem &&
+       match(Op0, m_NUWShl(m_Specific(Op1), m_Value()))))
+    return Constant::getNullValue(Op0->getType());
+
   // If the operation is with the result of a select instruction, check whether
   // operating on either branch of the select always yields the same value.
   if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1))
@@ -1064,6 +1081,10 @@
 /// If not, this returns null.
 static Value *SimplifySDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
                                unsigned MaxRecurse) {
+  // If two operands are negated and no signed overflow, return -1.
+  if (isKnownNegation(Op0, Op1, /*NeedNSW=*/true))
+    return Constant::getAllOnesValue(Op0->getType());
+
   return simplifyDiv(Instruction::SDiv, Op0, Op1, Q, MaxRecurse);
 }
@@ -1086,6 +1107,16 @@ Value *llvm::SimplifyUDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) {
 /// If not, this returns null.
 static Value *SimplifySRemInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
                                unsigned MaxRecurse) {
+  // If the divisor is 0, the result is undefined, so assume the divisor is -1.
+  // srem Op0, (sext i1 X) --> srem Op0, -1 --> 0
+  Value *X;
+  if (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1))
+    return ConstantInt::getNullValue(Op0->getType());
+
+  // If the two operands are negated, return 0.
+  if (isKnownNegation(Op0, Op1))
+    return ConstantInt::getNullValue(Op0->getType());
+
   return simplifyRem(Instruction::SRem, Op0, Op1, Q, MaxRecurse);
 }
@@ -1140,10 +1171,14 @@ static Value *SimplifyShift(Instruction::BinaryOps Opcode, Value *Op0,
 
   // 0 shift by X -> 0
   if (match(Op0, m_Zero()))
-    return Op0;
+    return Constant::getNullValue(Op0->getType());
 
   // X shift by 0 -> X
-  if (match(Op1, m_Zero()))
+  // Shift-by-sign-extended bool must be shift-by-0 because shift-by-all-ones
+  // would be poison.
+  Value *X;
+  if (match(Op1, m_Zero()) ||
+      (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)))
     return Op0;
 
   // Fold undefined shifts.
@@ -1177,7 +1212,7 @@ static Value *SimplifyShift(Instruction::BinaryOps Opcode, Value *Op0,
   return nullptr;
 }
 
-/// \brief Given operands for an Shl, LShr or AShr, see if we can
+/// Given operands for an Shl, LShr or AShr, see if we can
 /// fold the result. If not, this returns null.
 static Value *SimplifyRightShift(Instruction::BinaryOps Opcode, Value *Op0,
                                  Value *Op1, bool isExact, const SimplifyQuery &Q,
@@ -1220,6 +1255,13 @@ static Value *SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
   Value *X;
   if (match(Op0, m_Exact(m_Shr(m_Value(X), m_Specific(Op1)))))
     return X;
+
+  // shl nuw i8 C, %x -> C iff C has sign bit set.
+  if (isNUW && match(Op0, m_Negative()))
+    return Op0;
+  // NOTE: could use computeKnownBits() / LazyValueInfo,
+  // but the cost-benefit analysis suggests it isn't worth it.
+
   return nullptr;
 }
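Two of the folds above, sketched as IR (invented names, not from the commit's
tests):

  define i8 @srem_of_negation(i8 %x) {
    ; -%x and %x divide evenly, so the remainder is 0.
    %neg = sub i8 0, %x
    %r = srem i8 %neg, %x
    ret i8 %r
  }

  define i8 @shl_nuw_negative(i8 %x) {
    ; nuw forbids shifting out the set sign bit, so %x must be 0
    ; and the shift folds to the constant itself.
    %r = shl nuw i8 -128, %x
    ret i8 %r
  }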
@@ -1257,9 +1299,10 @@ static Value *SimplifyAShrInst(Value *Op0, Value *Op1, bool isExact,
                               MaxRecurse))
     return V;
 
-  // all ones >>a X -> all ones
+  // all ones >>a X -> -1
+  // Do not return Op0 because it may contain undef elements if it's a vector.
   if (match(Op0, m_AllOnes()))
-    return Op0;
+    return Constant::getAllOnesValue(Op0->getType());
 
   // (X << A) >> A -> X
   Value *X;
@@ -1295,7 +1338,7 @@ static Value *simplifyUnsignedRangeCheck(ICmpInst *ZeroICmp,
       ICmpInst::isUnsigned(UnsignedPred))
     ;
   else if (match(UnsignedICmp,
-                 m_ICmp(UnsignedPred, m_Value(Y), m_Specific(X))) &&
+                 m_ICmp(UnsignedPred, m_Specific(Y), m_Value(X))) &&
           ICmpInst::isUnsigned(UnsignedPred))
     UnsignedPred = ICmpInst::getSwappedPredicate(UnsignedPred);
   else
@@ -1413,6 +1456,43 @@ static Value *simplifyAndOrOfICmpsWithConstants(ICmpInst *Cmp0, ICmpInst *Cmp1,
   return nullptr;
 }
 
+static Value *simplifyAndOrOfICmpsWithZero(ICmpInst *Cmp0, ICmpInst *Cmp1,
+                                           bool IsAnd) {
+  ICmpInst::Predicate P0 = Cmp0->getPredicate(), P1 = Cmp1->getPredicate();
+  if (!match(Cmp0->getOperand(1), m_Zero()) ||
+      !match(Cmp1->getOperand(1), m_Zero()) || P0 != P1)
+    return nullptr;
+
+  if ((IsAnd && P0 != ICmpInst::ICMP_NE) || (!IsAnd && P1 != ICmpInst::ICMP_EQ))
+    return nullptr;
+
+  // We have either "(X == 0 || Y == 0)" or "(X != 0 && Y != 0)".
+  Value *X = Cmp0->getOperand(0);
+  Value *Y = Cmp1->getOperand(0);
+
+  // If one of the compares is a masked version of a (not) null check, then
+  // that compare implies the other, so we eliminate the other. Optionally, look
+  // through a pointer-to-int cast to match a null check of a pointer type.
+
+  // (X == 0) || (([ptrtoint] X & ?) == 0) --> ([ptrtoint] X & ?) == 0
+  // (X == 0) || ((? & [ptrtoint] X) == 0) --> (? & [ptrtoint] X) == 0
+  // (X != 0) && (([ptrtoint] X & ?) != 0) --> ([ptrtoint] X & ?) != 0
+  // (X != 0) && ((? & [ptrtoint] X) != 0) --> (? & [ptrtoint] X) != 0
+  if (match(Y, m_c_And(m_Specific(X), m_Value())) ||
+      match(Y, m_c_And(m_PtrToInt(m_Specific(X)), m_Value())))
+    return Cmp1;
+
+  // (([ptrtoint] Y & ?) == 0) || (Y == 0) --> ([ptrtoint] Y & ?) == 0
+  // ((? & [ptrtoint] Y) == 0) || (Y == 0) --> (? & [ptrtoint] Y) == 0
+  // (([ptrtoint] Y & ?) != 0) && (Y != 0) --> ([ptrtoint] Y & ?) != 0
+  // ((? & [ptrtoint] Y) != 0) && (Y != 0) --> (? & [ptrtoint] Y) != 0
+  if (match(X, m_c_And(m_Specific(Y), m_Value())) ||
+      match(X, m_c_And(m_PtrToInt(m_Specific(Y)), m_Value())))
+    return Cmp0;
+
+  return nullptr;
+}
+
 static Value *simplifyAndOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1) {
   // (icmp (add V, C0), C1) & (icmp V, C0)
   ICmpInst::Predicate Pred0, Pred1;
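simplifyAndOrOfICmpsWithZero in action (hand-written sketch, invented names):

  define i1 @masked_nonnull_check(i64 %x, i64 %m) {
    ; If (%x & %m) is nonzero then %x is certainly nonzero, so the
    ; 'and' of the two checks reduces to the masked compare %c2.
    %a = and i64 %x, %m
    %c1 = icmp ne i64 %x, 0
    %c2 = icmp ne i64 %a, 0
    %r = and i1 %c1, %c2
    ret i1 %r
  }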
@@ -1473,6 +1553,9 @@ static Value *simplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1) {
   if (Value *X = simplifyAndOrOfICmpsWithConstants(Op0, Op1, true))
     return X;
 
+  if (Value *X = simplifyAndOrOfICmpsWithZero(Op0, Op1, true))
+    return X;
+
   if (Value *X = simplifyAndOfICmpsWithAdd(Op0, Op1))
     return X;
   if (Value *X = simplifyAndOfICmpsWithAdd(Op1, Op0))
@@ -1541,6 +1624,9 @@ static Value *simplifyOrOfICmps(ICmpInst *Op0, ICmpInst *Op1) {
   if (Value *X = simplifyAndOrOfICmpsWithConstants(Op0, Op1, false))
     return X;
 
+  if (Value *X = simplifyAndOrOfICmpsWithZero(Op0, Op1, false))
+    return X;
+
   if (Value *X = simplifyOrOfICmpsWithAdd(Op0, Op1))
     return X;
   if (Value *X = simplifyOrOfICmpsWithAdd(Op1, Op0))
@@ -1638,7 +1724,7 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
 
   // X & 0 = 0
   if (match(Op1, m_Zero()))
-    return Op1;
+    return Constant::getNullValue(Op0->getType());
 
   // X & -1 = X
   if (match(Op1, m_AllOnes()))
@@ -1733,21 +1819,16 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
     return C;
 
   // X | undef -> -1
-  if (match(Op1, m_Undef()))
+  // X | -1 = -1
+  // Do not return Op1 because it may contain undef elements if it's a vector.
+  if (match(Op1, m_Undef()) || match(Op1, m_AllOnes()))
     return Constant::getAllOnesValue(Op0->getType());
 
   // X | X = X
-  if (Op0 == Op1)
-    return Op0;
-
   // X | 0 = X
-  if (match(Op1, m_Zero()))
+  if (Op0 == Op1 || match(Op1, m_Zero()))
     return Op0;
 
-  // X | -1 = -1
-  if (match(Op1, m_AllOnes()))
-    return Op1;
-
   // A | ~A = ~A | A = -1
   if (match(Op0, m_Not(m_Specific(Op1))) ||
       match(Op1, m_Not(m_Specific(Op0))))
@@ -2051,9 +2132,12 @@ computePointerICmp(const DataLayout &DL, const TargetLibraryInfo *TLI,
     ConstantInt *LHSOffsetCI = dyn_cast<ConstantInt>(LHSOffset);
     ConstantInt *RHSOffsetCI = dyn_cast<ConstantInt>(RHSOffset);
     uint64_t LHSSize, RHSSize;
+    ObjectSizeOpts Opts;
+    Opts.NullIsUnknownSize =
+        NullPointerIsDefined(cast<AllocaInst>(LHS)->getFunction());
     if (LHSOffsetCI && RHSOffsetCI &&
-        getObjectSize(LHS, LHSSize, DL, TLI) &&
-        getObjectSize(RHS, RHSSize, DL, TLI)) {
+        getObjectSize(LHS, LHSSize, DL, TLI, Opts) &&
+        getObjectSize(RHS, RHSSize, DL, TLI, Opts)) {
       const APInt &LHSOffsetValue = LHSOffsetCI->getValue();
       const APInt &RHSOffsetValue = RHSOffsetCI->getValue();
       if (!LHSOffsetValue.isNegative() &&
@@ -2442,6 +2526,20 @@ static void setLimitsForBinOp(BinaryOperator &BO, APInt &Lower, APInt &Upper) {
 
 static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
                                        Value *RHS) {
+  Type *ITy = GetCompareTy(RHS); // The return type.
+
+  Value *X;
+  // Sign-bit checks can be optimized to true/false after unsigned
+  // floating-point casts:
+  // icmp slt (bitcast (uitofp X)),  0 --> false
+  // icmp sgt (bitcast (uitofp X)), -1 --> true
+  if (match(LHS, m_BitCast(m_UIToFP(m_Value(X))))) {
+    if (Pred == ICmpInst::ICMP_SLT && match(RHS, m_Zero()))
+      return ConstantInt::getFalse(ITy);
+    if (Pred == ICmpInst::ICMP_SGT && match(RHS, m_AllOnes()))
+      return ConstantInt::getTrue(ITy);
+  }
+
   const APInt *C;
   if (!match(RHS, m_APInt(C)))
     return nullptr;
@@ -2449,9 +2547,9 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
   // Rule out tautological comparisons (eg., ult 0 or uge 0).
   ConstantRange RHS_CR = ConstantRange::makeExactICmpRegion(Pred, *C);
   if (RHS_CR.isEmptySet())
-    return ConstantInt::getFalse(GetCompareTy(RHS));
+    return ConstantInt::getFalse(ITy);
   if (RHS_CR.isFullSet())
-    return ConstantInt::getTrue(GetCompareTy(RHS));
+    return ConstantInt::getTrue(ITy);
 
   // Find the range of possible values for binary operators.
   unsigned Width = C->getBitWidth();
@@ -2469,9 +2567,9 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
 
   if (!LHS_CR.isFullSet()) {
     if (RHS_CR.contains(LHS_CR))
-      return ConstantInt::getTrue(GetCompareTy(RHS));
+      return ConstantInt::getTrue(ITy);
     if (RHS_CR.inverse().contains(LHS_CR))
-      return ConstantInt::getFalse(GetCompareTy(RHS));
+      return ConstantInt::getFalse(ITy);
   }
 
   return nullptr;
@@ -3008,8 +3106,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
   Type *ITy = GetCompareTy(LHS); // The return type.
 
   // icmp X, X -> true/false
-  // X icmp undef -> true/false. For example, icmp ugt %X, undef -> false
-  // because X could be 0.
+  // icmp X, undef -> true/false because undef could be X.
   if (LHS == RHS || isa<UndefValue>(RHS))
     return ConstantInt::get(ITy, CmpInst::isTrueWhenEqual(Pred));
@@ -3309,6 +3406,12 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
       return getTrue(RetTy);
   }
 
+  // NaN is unordered; NaN is not ordered.
+  assert((FCmpInst::isOrdered(Pred) || FCmpInst::isUnordered(Pred)) &&
+         "Comparison must be either ordered or unordered");
+  if (match(RHS, m_NaN()))
+    return ConstantInt::get(RetTy, CmpInst::isUnordered(Pred));
+
   // fcmp pred x, undef and fcmp pred undef, x
   // fold to true if unordered, false if ordered
   if (isa<UndefValue>(LHS) || isa<UndefValue>(RHS)) {
@@ -3328,15 +3431,6 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
   // Handle fcmp with constant RHS.
   const APFloat *C;
   if (match(RHS, m_APFloat(C))) {
-    // If the constant is a nan, see if we can fold the comparison based on it.
-    if (C->isNaN()) {
-      if (FCmpInst::isOrdered(Pred)) // True "if ordered and foo"
-        return getFalse(RetTy);
-      assert(FCmpInst::isUnordered(Pred) &&
-             "Comparison must be either ordered or unordered!");
-      // True if unordered.
-      return getTrue(RetTy);
-    }
     // Check whether the constant is an infinity.
     if (C->isInfinity()) {
       if (C->isNegative()) {
@@ -3475,6 +3569,17 @@ static const Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
     }
   }
 
+  // Same for GEPs.
+  if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) {
+    if (MaxRecurse) {
+      SmallVector<Value *, 8> NewOps(GEP->getNumOperands());
+      transform(GEP->operands(), NewOps.begin(),
+                [&](Value *V) { return V == Op ? RepOp : V; });
+      return SimplifyGEPInst(GEP->getSourceElementType(), NewOps, Q,
+                             MaxRecurse - 1);
+    }
+  }
+
   // TODO: We could hand off more cases to instsimplify here.
 
   // If all operands are constant after substituting Op for RepOp then we can
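With the constant-NaN handling hoisted out of the m_APFloat block, a comparison
against NaN now folds purely on the predicate's ordering, e.g. (hand-written
sketch, invented names):

  define i1 @fcmp_ult_nan(double %x) {
    ; ult is an unordered predicate and NaN compares unordered with
    ; everything, so this folds to true (olt would fold to false).
    %r = fcmp ult double %x, 0x7FF8000000000000
    ret i1 %r
  }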
@@ -3581,24 +3686,6 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
                                   TrueVal, FalseVal))
     return V;
 
-  if (CondVal->hasOneUse()) {
-    const APInt *C;
-    if (match(CmpRHS, m_APInt(C))) {
-      // X < MIN ? T : F --> F
-      if (Pred == ICmpInst::ICMP_SLT && C->isMinSignedValue())
-        return FalseVal;
-      // X < MIN ? T : F --> F
-      if (Pred == ICmpInst::ICMP_ULT && C->isMinValue())
-        return FalseVal;
-      // X > MAX ? T : F --> F
-      if (Pred == ICmpInst::ICMP_SGT && C->isMaxSignedValue())
-        return FalseVal;
-      // X > MAX ? T : F --> F
-      if (Pred == ICmpInst::ICMP_UGT && C->isMaxValue())
-        return FalseVal;
-    }
-  }
-
   // If we have an equality comparison, then we know the value in one of the
   // arms of the select. See if substituting this value into the arm and
   // simplifying the result yields the same value as the other arm.
@@ -3631,37 +3718,38 @@
 
 /// Given operands for a SelectInst, see if we can fold the result.
 /// If not, this returns null.
-static Value *SimplifySelectInst(Value *CondVal, Value *TrueVal,
-                                 Value *FalseVal, const SimplifyQuery &Q,
-                                 unsigned MaxRecurse) {
-  // select true, X, Y  -> X
-  // select false, X, Y -> Y
-  if (Constant *CB = dyn_cast<Constant>(CondVal)) {
-    if (Constant *CT = dyn_cast<Constant>(TrueVal))
-      if (Constant *CF = dyn_cast<Constant>(FalseVal))
-        return ConstantFoldSelectInstruction(CB, CT, CF);
-    if (CB->isAllOnesValue())
+static Value *SimplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal,
+                                 const SimplifyQuery &Q, unsigned MaxRecurse) {
+  if (auto *CondC = dyn_cast<Constant>(Cond)) {
+    if (auto *TrueC = dyn_cast<Constant>(TrueVal))
+      if (auto *FalseC = dyn_cast<Constant>(FalseVal))
+        return ConstantFoldSelectInstruction(CondC, TrueC, FalseC);
+
+    // select undef, X, Y -> X or Y
+    if (isa<UndefValue>(CondC))
+      return isa<Constant>(FalseVal) ? FalseVal : TrueVal;
+
+    // TODO: Vector constants with undef elements don't simplify.
+
+    // select true, X, Y -> X
+    if (CondC->isAllOnesValue())
       return TrueVal;
-    if (CB->isNullValue())
+    // select false, X, Y -> Y
+    if (CondC->isNullValue())
       return FalseVal;
   }
 
-  // select C, X, X -> X
+  // select ?, X, X -> X
   if (TrueVal == FalseVal)
     return TrueVal;
 
-  if (isa<UndefValue>(CondVal)) {  // select undef, X, Y -> X or Y
-    if (isa<Constant>(FalseVal))
-      return FalseVal;
-    return TrueVal;
-  }
-  if (isa<UndefValue>(TrueVal))   // select C, undef, X -> X
+  if (isa<UndefValue>(TrueVal))   // select ?, undef, X -> X
     return FalseVal;
-  if (isa<UndefValue>(FalseVal))   // select C, X, undef -> X
+  if (isa<UndefValue>(FalseVal))  // select ?, X, undef -> X
    return TrueVal;
 
   if (Value *V =
-          simplifySelectWithICmpCond(CondVal, TrueVal, FalseVal, Q, MaxRecurse))
+          simplifySelectWithICmpCond(Cond, TrueVal, FalseVal, Q, MaxRecurse))
     return V;
 
   return nullptr;
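The rewritten select logic keeps the old undef behavior: an undef condition may
resolve either way, so a constant arm is preferred (hand-written sketch,
invented names):

  define i32 @select_undef_cond(i32 %x) {
    ; Either arm may be chosen; picking the constant arm folds this to 42.
    %r = select i1 undef, i32 %x, i32 42
    ret i32 %r
  }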
@@ -3697,7 +3785,7 @@ static Value *SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops,
 
   if (Ops.size() == 2) {
     // getelementptr P, 0 -> P.
-    if (match(Ops[1], m_Zero()))
+    if (match(Ops[1], m_Zero()) && Ops[0]->getType() == GEPTy)
       return Ops[0];
 
     Type *Ty = SrcTy;
@@ -3706,13 +3794,13 @@
       uint64_t C;
       uint64_t TyAllocSize = Q.DL.getTypeAllocSize(Ty);
       // getelementptr P, N -> P if P points to a type of zero size.
-      if (TyAllocSize == 0)
+      if (TyAllocSize == 0 && Ops[0]->getType() == GEPTy)
         return Ops[0];
 
       // The following transforms are only safe if the ptrtoint cast
       // doesn't truncate the pointers.
       if (Ops[1]->getType()->getScalarSizeInBits() ==
-          Q.DL.getPointerSizeInBits(AS)) {
+          Q.DL.getIndexSizeInBits(AS)) {
         auto PtrToIntOrZero = [GEPTy](Value *P) -> Value * {
           if (match(P, m_Zero()))
             return Constant::getNullValue(GEPTy);
@@ -3752,10 +3840,10 @@
   if (Q.DL.getTypeAllocSize(LastType) == 1 &&
       all_of(Ops.slice(1).drop_back(1),
             [](Value *Idx) { return match(Idx, m_Zero()); })) {
-    unsigned PtrWidth =
-        Q.DL.getPointerSizeInBits(Ops[0]->getType()->getPointerAddressSpace());
-    if (Q.DL.getTypeSizeInBits(Ops.back()->getType()) == PtrWidth) {
-      APInt BasePtrOffset(PtrWidth, 0);
+    unsigned IdxWidth =
+        Q.DL.getIndexSizeInBits(Ops[0]->getType()->getPointerAddressSpace());
+    if (Q.DL.getTypeSizeInBits(Ops.back()->getType()) == IdxWidth) {
+      APInt BasePtrOffset(IdxWidth, 0);
       Value *StrippedBasePtr =
           Ops[0]->stripAndAccumulateInBoundsConstantOffsets(Q.DL,
                                                             BasePtrOffset);
@@ -3838,12 +3926,13 @@ Value *llvm::SimplifyInsertElementInst(Value *Vec, Value *Val, Value *Idx,
   // Fold into undef if index is out of bounds.
   if (auto *CI = dyn_cast<ConstantInt>(Idx)) {
     uint64_t NumElements = cast<VectorType>(Vec->getType())->getNumElements();
-
     if (CI->uge(NumElements))
       return UndefValue::get(Vec->getType());
   }
 
-  // TODO: We should also fold if index is iteslf an undef.
+  // If index is undef, it might be out of bounds (see above case)
+  if (isa<UndefValue>(Idx))
+    return UndefValue::get(Vec->getType());
 
   return nullptr;
 }
@@ -3896,10 +3985,13 @@ static Value *SimplifyExtractElementInst(Value *Vec, Value *Idx, const SimplifyQ
 
   // If extracting a specified index from the vector, see if we can recursively
   // find a previously computed scalar that was inserted into the vector.
-  if (auto *IdxC = dyn_cast<ConstantInt>(Idx))
-    if (IdxC->getValue().ule(Vec->getType()->getVectorNumElements()))
-      if (Value *Elt = findScalarElement(Vec, IdxC->getZExtValue()))
-        return Elt;
+  if (auto *IdxC = dyn_cast<ConstantInt>(Idx)) {
+    if (IdxC->getValue().uge(Vec->getType()->getVectorNumElements()))
+      // definitely out of bounds, thus undefined result
+      return UndefValue::get(Vec->getType()->getVectorElementType());
+    if (Value *Elt = findScalarElement(Vec, IdxC->getZExtValue()))
+      return Elt;
+  }
 
   // An undef extract index can be arbitrarily chosen to be an out-of-range
   // index value, which would result in the instruction being undef.
@@ -3942,7 +4034,7 @@ static Value *SimplifyPHINode(PHINode *PN, const SimplifyQuery &Q) {
   // instruction, we cannot return X as the result of the PHI node unless it
   // dominates the PHI block.
   if (HasUndefInput)
-    return ValueDominatesPHI(CommonValue, PN, Q.DT) ? CommonValue : nullptr;
+    return valueDominatesPHI(CommonValue, PN, Q.DT) ? CommonValue : nullptr;
 
   return CommonValue;
 }
@@ -4119,6 +4211,28 @@ Value *llvm::SimplifyShuffleVectorInst(Value *Op0, Value *Op1, Constant *Mask,
   return ::SimplifyShuffleVectorInst(Op0, Op1, Mask, RetTy, Q, RecursionLimit);
 }
 
+static Constant *propagateNaN(Constant *In) {
+  // If the input is a vector with undef elements, just return a default NaN.
+  if (!In->isNaN())
+    return ConstantFP::getNaN(In->getType());
+
+  // Propagate the existing NaN constant when possible.
+  // TODO: Should we quiet a signaling NaN?
+  return In;
+}
+
+static Constant *simplifyFPBinop(Value *Op0, Value *Op1) {
+  if (isa<UndefValue>(Op0) || isa<UndefValue>(Op1))
+    return ConstantFP::getNaN(Op0->getType());
+
+  if (match(Op0, m_NaN()))
+    return propagateNaN(cast<Constant>(Op0));
+  if (match(Op1, m_NaN()))
+    return propagateNaN(cast<Constant>(Op1));
+
+  return nullptr;
+}
+
 /// Given operands for an FAdd, see if we can fold the result. If not, this
 /// returns null.
 static Value *SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF,
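The new simplifyFPBinop helper gives all the FP binops a common early exit
(hand-written sketch, invented names):

  define double @fadd_nan(double %x) {
    ; A NaN operand propagates through any FP binop.
    %r = fadd double %x, 0x7FF8000000000000
    ret double %r
  }

  define double @fdiv_undef(double %x) {
    ; undef could be a (signaling) NaN, so the result folds to NaN.
    %r = fdiv double %x, undef
    ret double %r
  }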
@@ -4126,29 +4240,28 @@ static Value *SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF,
   if (Constant *C = foldOrCommuteConstant(Instruction::FAdd, Op0, Op1, Q))
     return C;
 
+  if (Constant *C = simplifyFPBinop(Op0, Op1))
+    return C;
+
   // fadd X, -0 ==> X
-  if (match(Op1, m_NegZero()))
+  if (match(Op1, m_NegZeroFP()))
     return Op0;
 
   // fadd X, 0 ==> X, when we know X is not -0
-  if (match(Op1, m_Zero()) &&
+  if (match(Op1, m_PosZeroFP()) &&
       (FMF.noSignedZeros() || CannotBeNegativeZero(Op0, Q.TLI)))
     return Op0;
 
-  // fadd [nnan ninf] X, (fsub [nnan ninf] 0, X) ==> 0
-  //   where nnan and ninf have to occur at least once somewhere in this
-  //   expression
-  Value *SubOp = nullptr;
-  if (match(Op1, m_FSub(m_AnyZero(), m_Specific(Op0))))
-    SubOp = Op1;
-  else if (match(Op0, m_FSub(m_AnyZero(), m_Specific(Op1))))
-    SubOp = Op0;
-  if (SubOp) {
-    Instruction *FSub = cast<Instruction>(SubOp);
-    if ((FMF.noNaNs() || FSub->hasNoNaNs()) &&
-        (FMF.noInfs() || FSub->hasNoInfs()))
-      return Constant::getNullValue(Op0->getType());
-  }
+  // With nnan: (+/-0.0 - X) + X --> 0.0 (and commuted variant)
+  // We don't have to explicitly exclude infinities (ninf): INF + -INF == NaN.
+  // Negative zeros are allowed because we always end up with positive zero:
+  // X = -0.0: (-0.0 - (-0.0)) + (-0.0) == ( 0.0) + (-0.0) == 0.0
+  // X = -0.0: ( 0.0 - (-0.0)) + (-0.0) == ( 0.0) + (-0.0) == 0.0
+  // X =  0.0: (-0.0 - ( 0.0)) + ( 0.0) == (-0.0) + ( 0.0) == 0.0
+  // X =  0.0: ( 0.0 - ( 0.0)) + ( 0.0) == ( 0.0) + ( 0.0) == 0.0
+  if (FMF.noNaNs() && (match(Op0, m_FSub(m_AnyZeroFP(), m_Specific(Op1))) ||
+                       match(Op1, m_FSub(m_AnyZeroFP(), m_Specific(Op0)))))
+    return ConstantFP::getNullValue(Op0->getType());
 
   return nullptr;
 }
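The relaxed fold now needs only nnan on the fadd itself (hand-written sketch,
invented names):

  define float @fadd_of_fsub_nnan(float %x) {
    ; (0.0 - %x) + %x is +0.0 for any value of %x once NaNs are
    ; excluded; infinities would give INF + -INF == NaN.
    %neg = fsub float 0.0, %x
    %r = fadd nnan float %neg, %x
    ret float %r
  }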
@@ -4160,23 +4273,27 @@ static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF,
   if (Constant *C = foldOrCommuteConstant(Instruction::FSub, Op0, Op1, Q))
     return C;
 
-  // fsub X, 0 ==> X
-  if (match(Op1, m_Zero()))
+  if (Constant *C = simplifyFPBinop(Op0, Op1))
+    return C;
+
+  // fsub X, +0 ==> X
+  if (match(Op1, m_PosZeroFP()))
     return Op0;
 
   // fsub X, -0 ==> X, when we know X is not -0
-  if (match(Op1, m_NegZero()) &&
+  if (match(Op1, m_NegZeroFP()) &&
       (FMF.noSignedZeros() || CannotBeNegativeZero(Op0, Q.TLI)))
     return Op0;
 
   // fsub -0.0, (fsub -0.0, X) ==> X
   Value *X;
-  if (match(Op0, m_NegZero()) && match(Op1, m_FSub(m_NegZero(), m_Value(X))))
+  if (match(Op0, m_NegZeroFP()) &&
+      match(Op1, m_FSub(m_NegZeroFP(), m_Value(X))))
     return X;
 
   // fsub 0.0, (fsub 0.0, X) ==> X if signed zeros are ignored.
-  if (FMF.noSignedZeros() && match(Op0, m_AnyZero()) &&
-      match(Op1, m_FSub(m_AnyZero(), m_Value(X))))
+  if (FMF.noSignedZeros() && match(Op0, m_AnyZeroFP()) &&
+      match(Op1, m_FSub(m_AnyZeroFP(), m_Value(X))))
     return X;
 
   // fsub nnan x, x ==> 0.0
@@ -4192,13 +4309,25 @@ static Value *SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF,
   if (Constant *C = foldOrCommuteConstant(Instruction::FMul, Op0, Op1, Q))
     return C;
 
+  if (Constant *C = simplifyFPBinop(Op0, Op1))
+    return C;
+
   // fmul X, 1.0 ==> X
   if (match(Op1, m_FPOne()))
     return Op0;
 
   // fmul nnan nsz X, 0 ==> 0
-  if (FMF.noNaNs() && FMF.noSignedZeros() && match(Op1, m_AnyZero()))
-    return Op1;
+  if (FMF.noNaNs() && FMF.noSignedZeros() && match(Op1, m_AnyZeroFP()))
+    return ConstantFP::getNullValue(Op0->getType());
+
+  // sqrt(X) * sqrt(X) --> X, if we can:
+  // 1. Remove the intermediate rounding (reassociate).
+  // 2. Ignore non-zero negative numbers because sqrt would produce NAN.
+  // 3. Ignore -0.0 because sqrt(-0.0) == -0.0, but -0.0 * -0.0 == 0.0.
+  Value *X;
+  if (Op0 == Op1 && match(Op0, m_Intrinsic<Intrinsic::sqrt>(m_Value(X))) &&
+      FMF.allowReassoc() && FMF.noNaNs() && FMF.noSignedZeros())
+    return X;
 
   return nullptr;
 }
@@ -4224,13 +4353,8 @@ static Value *SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF,
   if (Constant *C = foldOrCommuteConstant(Instruction::FDiv, Op0, Op1, Q))
     return C;
 
-  // undef / X -> undef    (the undef could be a snan).
-  if (match(Op0, m_Undef()))
-    return Op0;
-
-  // X / undef -> undef
-  if (match(Op1, m_Undef()))
-    return Op1;
+  if (Constant *C = simplifyFPBinop(Op0, Op1))
+    return C;
 
   // X / 1.0 -> X
   if (match(Op1, m_FPOne()))
@@ -4239,14 +4363,20 @@ static Value *SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF,
   // 0 / X -> 0
   // Requires that NaNs are off (X could be zero) and signed zeroes are
   // ignored (X could be positive or negative, so the output sign is unknown).
-  if (FMF.noNaNs() && FMF.noSignedZeros() && match(Op0, m_AnyZero()))
-    return Op0;
+  if (FMF.noNaNs() && FMF.noSignedZeros() && match(Op0, m_AnyZeroFP()))
+    return ConstantFP::getNullValue(Op0->getType());
 
   if (FMF.noNaNs()) {
     // X / X -> 1.0 is legal when NaNs are ignored.
+    // We can ignore infinities because INF/INF is NaN.
     if (Op0 == Op1)
       return ConstantFP::get(Op0->getType(), 1.0);
 
+    // (X * Y) / Y --> X if we can reassociate to the above form.
+    Value *X;
+    if (FMF.allowReassoc() && match(Op0, m_c_FMul(m_Value(X), m_Specific(Op1))))
+      return X;
+
     // -X /  X -> -1.0 and
     //  X / -X -> -1.0 are legal when NaNs are ignored.
     // We can ignore signed zeros because +-0.0/+-0.0 is NaN and ignored.
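Two of the new reassociation-gated folds (hand-written sketch, invented names;
the declaration is the standard sqrt intrinsic):

  declare double @llvm.sqrt.f64(double)

  define double @sqrt_squared(double %x) {
    ; reassoc drops the intermediate rounding, nnan rules out negative
    ; inputs (sqrt would give NaN), and nsz ignores the -0.0 corner.
    %s = call double @llvm.sqrt.f64(double %x)
    %r = fmul reassoc nnan nsz double %s, %s
    ret double %r
  }

  define double @fmul_then_fdiv(double %x, double %y) {
    ; (X * Y) / Y --> X needs nnan (Y could be 0 or INF) and reassoc.
    %m = fmul double %x, %y
    %r = fdiv reassoc nnan double %m, %y
    ret double %r
  }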
@@ -4270,19 +4400,20 @@ static Value *SimplifyFRemInst(Value *Op0, Value *Op1, FastMathFlags FMF,
   if (Constant *C = foldOrCommuteConstant(Instruction::FRem, Op0, Op1, Q))
     return C;
 
-  // undef % X -> undef    (the undef could be a snan).
-  if (match(Op0, m_Undef()))
-    return Op0;
-
-  // X % undef -> undef
-  if (match(Op1, m_Undef()))
-    return Op1;
+  if (Constant *C = simplifyFPBinop(Op0, Op1))
+    return C;
 
-  // 0 % X -> 0
-  // Requires that NaNs are off (X could be zero) and signed zeroes are
-  // ignored (X could be positive or negative, so the output sign is unknown).
-  if (FMF.noNaNs() && FMF.noSignedZeros() && match(Op0, m_AnyZero()))
-    return Op0;
+  // Unlike fdiv, the result of frem always matches the sign of the dividend.
+  // The constant match may include undef elements in a vector, so return a full
+  // zero constant as the result.
+  if (FMF.noNaNs()) {
+    // +0 % X -> 0
+    if (match(Op0, m_PosZeroFP()))
+      return ConstantFP::getNullValue(Op0->getType());
+    // -0 % X -> -0
+    if (match(Op0, m_NegZeroFP()))
+      return ConstantFP::getNegativeZero(Op0->getType());
+  }
 
   return nullptr;
 }
@@ -4489,28 +4620,55 @@ static Value *SimplifyIntrinsic(Function *F, IterTy ArgBegin, IterTy ArgEnd,
     }
   }
 
+  Value *IIOperand = *ArgBegin;
+  Value *X;
   switch (IID) {
   case Intrinsic::fabs: {
-    if (SignBitMustBeZero(*ArgBegin, Q.TLI))
-      return *ArgBegin;
+    if (SignBitMustBeZero(IIOperand, Q.TLI))
+      return IIOperand;
     return nullptr;
   }
   case Intrinsic::bswap: {
-    Value *IIOperand = *ArgBegin;
-    Value *X = nullptr;
     // bswap(bswap(x)) -> x
     if (match(IIOperand, m_BSwap(m_Value(X))))
      return X;
    return nullptr;
   }
   case Intrinsic::bitreverse: {
-    Value *IIOperand = *ArgBegin;
-    Value *X = nullptr;
     // bitreverse(bitreverse(x)) -> x
     if (match(IIOperand, m_BitReverse(m_Value(X))))
      return X;
    return nullptr;
   }
+  case Intrinsic::exp: {
+    // exp(log(x)) -> x
+    if (Q.CxtI->hasAllowReassoc() &&
+        match(IIOperand, m_Intrinsic<Intrinsic::log>(m_Value(X))))
+      return X;
+    return nullptr;
+  }
+  case Intrinsic::exp2: {
+    // exp2(log2(x)) -> x
+    if (Q.CxtI->hasAllowReassoc() &&
+        match(IIOperand, m_Intrinsic<Intrinsic::log2>(m_Value(X))))
+      return X;
+    return nullptr;
+  }
+  case Intrinsic::log: {
+    // log(exp(x)) -> x
+    if (Q.CxtI->hasAllowReassoc() &&
+        match(IIOperand, m_Intrinsic<Intrinsic::exp>(m_Value(X))))
+      return X;
+    return nullptr;
+  }
+  case Intrinsic::log2: {
+    // log2(exp2(x)) -> x
+    if (Q.CxtI->hasAllowReassoc() &&
+        match(IIOperand, m_Intrinsic<Intrinsic::exp2>(m_Value(X)))) {
+      return X;
+    }
+    return nullptr;
+  }
   default:
     return nullptr;
   }
@@ -4575,6 +4733,14 @@ static Value *SimplifyIntrinsic(Function *F, IterTy ArgBegin, IterTy ArgEnd,
       return LHS;
     }
     return nullptr;
+  case Intrinsic::maxnum:
+  case Intrinsic::minnum:
+    // If one argument is NaN, return the other argument.
+    if (match(LHS, m_NaN()))
+      return RHS;
+    if (match(RHS, m_NaN()))
+      return LHS;
+    return nullptr;
   default:
     return nullptr;
   }
@@ -4812,7 +4978,7 @@ Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ,
   return Result == I ? UndefValue::get(I->getType()) : Result;
 }
 
-/// \brief Implementation of recursive simplification through an instruction's
+/// Implementation of recursive simplification through an instruction's
 /// uses.
 ///
 /// This is the common implementation of the recursive simplification routines.
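The new intrinsic folds, sketched as IR (invented names; the declarations are
the standard intrinsics):

  declare double @llvm.exp.f64(double)
  declare double @llvm.log.f64(double)
  declare double @llvm.maxnum.f64(double, double)

  define double @log_of_exp(double %x) {
    ; The outer call carries reassoc (via fast), so log(exp(x)) folds to %x.
    %e = call fast double @llvm.exp.f64(double %x)
    %r = call fast double @llvm.log.f64(double %e)
    ret double %r
  }

  define double @maxnum_with_nan(double %x) {
    ; maxnum/minnum return the non-NaN operand when one side is NaN.
    %r = call double @llvm.maxnum.f64(double %x, double 0x7FF8000000000000)
    ret double %r
  }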