diff options
Diffstat (limited to 'llvm/lib/Analysis/ScalarEvolution.cpp')
-rw-r--r-- | llvm/lib/Analysis/ScalarEvolution.cpp | 517 |
1 files changed, 186 insertions, 331 deletions
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 26a9a5ddf1ea7..48c686b732608 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -79,6 +79,7 @@ #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/ScalarEvolutionDivision.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" @@ -86,7 +87,6 @@ #include "llvm/IR/Argument.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" -#include "llvm/IR/CallSite.h" #include "llvm/IR/Constant.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/Constants.h" @@ -848,273 +848,14 @@ static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops, } } -// Returns the size of the SCEV S. -static inline int sizeOfSCEV(const SCEV *S) { - struct FindSCEVSize { - int Size = 0; - - FindSCEVSize() = default; - - bool follow(const SCEV *S) { - ++Size; - // Keep looking at all operands of S. - return true; - } - - bool isDone() const { - return false; - } - }; - - FindSCEVSize F; - SCEVTraversal<FindSCEVSize> ST(F); - ST.visitAll(S); - return F.Size; -} - -/// Returns true if the subtree of \p S contains at least HugeExprThreshold -/// nodes. -static bool isHugeExpression(const SCEV *S) { - return S->getExpressionSize() >= HugeExprThreshold; -} - -/// Returns true of \p Ops contains a huge SCEV (see definition above). +/// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at +/// least HugeExprThreshold nodes). static bool hasHugeExpression(ArrayRef<const SCEV *> Ops) { - return any_of(Ops, isHugeExpression); + return any_of(Ops, [](const SCEV *S) { + return S->getExpressionSize() >= HugeExprThreshold; + }); } -namespace { - -struct SCEVDivision : public SCEVVisitor<SCEVDivision, void> { -public: - // Computes the Quotient and Remainder of the division of Numerator by - // Denominator. - static void divide(ScalarEvolution &SE, const SCEV *Numerator, - const SCEV *Denominator, const SCEV **Quotient, - const SCEV **Remainder) { - assert(Numerator && Denominator && "Uninitialized SCEV"); - - SCEVDivision D(SE, Numerator, Denominator); - - // Check for the trivial case here to avoid having to check for it in the - // rest of the code. - if (Numerator == Denominator) { - *Quotient = D.One; - *Remainder = D.Zero; - return; - } - - if (Numerator->isZero()) { - *Quotient = D.Zero; - *Remainder = D.Zero; - return; - } - - // A simple case when N/1. The quotient is N. - if (Denominator->isOne()) { - *Quotient = Numerator; - *Remainder = D.Zero; - return; - } - - // Split the Denominator when it is a product. - if (const SCEVMulExpr *T = dyn_cast<SCEVMulExpr>(Denominator)) { - const SCEV *Q, *R; - *Quotient = Numerator; - for (const SCEV *Op : T->operands()) { - divide(SE, *Quotient, Op, &Q, &R); - *Quotient = Q; - - // Bail out when the Numerator is not divisible by one of the terms of - // the Denominator. - if (!R->isZero()) { - *Quotient = D.Zero; - *Remainder = Numerator; - return; - } - } - *Remainder = D.Zero; - return; - } - - D.visit(Numerator); - *Quotient = D.Quotient; - *Remainder = D.Remainder; - } - - // Except in the trivial case described above, we do not know how to divide - // Expr by Denominator for the following functions with empty implementation. - void visitTruncateExpr(const SCEVTruncateExpr *Numerator) {} - void visitZeroExtendExpr(const SCEVZeroExtendExpr *Numerator) {} - void visitSignExtendExpr(const SCEVSignExtendExpr *Numerator) {} - void visitUDivExpr(const SCEVUDivExpr *Numerator) {} - void visitSMaxExpr(const SCEVSMaxExpr *Numerator) {} - void visitUMaxExpr(const SCEVUMaxExpr *Numerator) {} - void visitSMinExpr(const SCEVSMinExpr *Numerator) {} - void visitUMinExpr(const SCEVUMinExpr *Numerator) {} - void visitUnknown(const SCEVUnknown *Numerator) {} - void visitCouldNotCompute(const SCEVCouldNotCompute *Numerator) {} - - void visitConstant(const SCEVConstant *Numerator) { - if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) { - APInt NumeratorVal = Numerator->getAPInt(); - APInt DenominatorVal = D->getAPInt(); - uint32_t NumeratorBW = NumeratorVal.getBitWidth(); - uint32_t DenominatorBW = DenominatorVal.getBitWidth(); - - if (NumeratorBW > DenominatorBW) - DenominatorVal = DenominatorVal.sext(NumeratorBW); - else if (NumeratorBW < DenominatorBW) - NumeratorVal = NumeratorVal.sext(DenominatorBW); - - APInt QuotientVal(NumeratorVal.getBitWidth(), 0); - APInt RemainderVal(NumeratorVal.getBitWidth(), 0); - APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal); - Quotient = SE.getConstant(QuotientVal); - Remainder = SE.getConstant(RemainderVal); - return; - } - } - - void visitAddRecExpr(const SCEVAddRecExpr *Numerator) { - const SCEV *StartQ, *StartR, *StepQ, *StepR; - if (!Numerator->isAffine()) - return cannotDivide(Numerator); - divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR); - divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR); - // Bail out if the types do not match. - Type *Ty = Denominator->getType(); - if (Ty != StartQ->getType() || Ty != StartR->getType() || - Ty != StepQ->getType() || Ty != StepR->getType()) - return cannotDivide(Numerator); - Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(), - Numerator->getNoWrapFlags()); - Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(), - Numerator->getNoWrapFlags()); - } - - void visitAddExpr(const SCEVAddExpr *Numerator) { - SmallVector<const SCEV *, 2> Qs, Rs; - Type *Ty = Denominator->getType(); - - for (const SCEV *Op : Numerator->operands()) { - const SCEV *Q, *R; - divide(SE, Op, Denominator, &Q, &R); - - // Bail out if types do not match. - if (Ty != Q->getType() || Ty != R->getType()) - return cannotDivide(Numerator); - - Qs.push_back(Q); - Rs.push_back(R); - } - - if (Qs.size() == 1) { - Quotient = Qs[0]; - Remainder = Rs[0]; - return; - } - - Quotient = SE.getAddExpr(Qs); - Remainder = SE.getAddExpr(Rs); - } - - void visitMulExpr(const SCEVMulExpr *Numerator) { - SmallVector<const SCEV *, 2> Qs; - Type *Ty = Denominator->getType(); - - bool FoundDenominatorTerm = false; - for (const SCEV *Op : Numerator->operands()) { - // Bail out if types do not match. - if (Ty != Op->getType()) - return cannotDivide(Numerator); - - if (FoundDenominatorTerm) { - Qs.push_back(Op); - continue; - } - - // Check whether Denominator divides one of the product operands. - const SCEV *Q, *R; - divide(SE, Op, Denominator, &Q, &R); - if (!R->isZero()) { - Qs.push_back(Op); - continue; - } - - // Bail out if types do not match. - if (Ty != Q->getType()) - return cannotDivide(Numerator); - - FoundDenominatorTerm = true; - Qs.push_back(Q); - } - - if (FoundDenominatorTerm) { - Remainder = Zero; - if (Qs.size() == 1) - Quotient = Qs[0]; - else - Quotient = SE.getMulExpr(Qs); - return; - } - - if (!isa<SCEVUnknown>(Denominator)) - return cannotDivide(Numerator); - - // The Remainder is obtained by replacing Denominator by 0 in Numerator. - ValueToValueMap RewriteMap; - RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = - cast<SCEVConstant>(Zero)->getValue(); - Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true); - - if (Remainder->isZero()) { - // The Quotient is obtained by replacing Denominator by 1 in Numerator. - RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = - cast<SCEVConstant>(One)->getValue(); - Quotient = - SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true); - return; - } - - // Quotient is (Numerator - Remainder) divided by Denominator. - const SCEV *Q, *R; - const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder); - // This SCEV does not seem to simplify: fail the division here. - if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator)) - return cannotDivide(Numerator); - divide(SE, Diff, Denominator, &Q, &R); - if (R != Zero) - return cannotDivide(Numerator); - Quotient = Q; - } - -private: - SCEVDivision(ScalarEvolution &S, const SCEV *Numerator, - const SCEV *Denominator) - : SE(S), Denominator(Denominator) { - Zero = SE.getZero(Denominator->getType()); - One = SE.getOne(Denominator->getType()); - - // We generally do not know how to divide Expr by Denominator. We - // initialize the division to a "cannot divide" state to simplify the rest - // of the code. - cannotDivide(Numerator); - } - - // Convenience function for giving up on the division. We set the quotient to - // be equal to zero and the remainder to be equal to the numerator. - void cannotDivide(const SCEV *Numerator) { - Quotient = Zero; - Remainder = Numerator; - } - - ScalarEvolution &SE; - const SCEV *Denominator, *Quotient, *Remainder, *Zero, *One; -}; - -} // end anonymous namespace - //===----------------------------------------------------------------------===// // Simple SCEV method implementations //===----------------------------------------------------------------------===// @@ -1612,7 +1353,7 @@ bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start, static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, const SCEVConstant *ConstantTerm, const SCEVAddExpr *WholeAddExpr) { - const APInt C = ConstantTerm->getAPInt(); + const APInt &C = ConstantTerm->getAPInt(); const unsigned BitWidth = C.getBitWidth(); // Find number of trailing zeros of (x + y + ...) w/o the C first: uint32_t TZ = BitWidth; @@ -2455,6 +2196,11 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, if (Depth > MaxArithDepth || hasHugeExpression(Ops)) return getOrCreateAddExpr(Ops, Flags); + if (SCEV *S = std::get<0>(findExistingSCEVInCache(scAddExpr, Ops))) { + static_cast<SCEVAddExpr *>(S)->setNoWrapFlags(Flags); + return S; + } + // Okay, check to see if the same value occurs in the operand list more than // once. If so, merge them together into an multiply expression. Since we // sorted the list, these values are required to be adjacent. @@ -2930,10 +2676,17 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, Flags = StrengthenNoWrapFlags(this, scMulExpr, Ops, Flags); - // Limit recursion calls depth. - if (Depth > MaxArithDepth || hasHugeExpression(Ops)) + // Limit recursion calls depth, but fold all-constant expressions. + // `Ops` is sorted, so it's enough to check just last one. + if ((Depth > MaxArithDepth || hasHugeExpression(Ops)) && + !isa<SCEVConstant>(Ops.back())) return getOrCreateMulExpr(Ops, Flags); + if (SCEV *S = std::get<0>(findExistingSCEVInCache(scMulExpr, Ops))) { + static_cast<SCEVMulExpr *>(S)->setNoWrapFlags(Flags); + return S; + } + // If there are any constants, fold them together. unsigned Idx = 0; if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) { @@ -3104,8 +2857,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, // Limit max number of arguments to avoid creation of unreasonably big // SCEVAddRecs with very complex operands. if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 > - MaxAddRecSize || isHugeExpression(AddRec) || - isHugeExpression(OtherAddRec)) + MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec})) continue; bool Overflow = false; @@ -3197,6 +2949,14 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, getEffectiveSCEVType(RHS->getType()) && "SCEVUDivExpr operand types don't match!"); + FoldingSetNodeID ID; + ID.AddInteger(scUDivExpr); + ID.AddPointer(LHS); + ID.AddPointer(RHS); + void *IP = nullptr; + if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) + return S; + if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) { if (RHSC->getValue()->isOne()) return LHS; // X udiv 1 --> x @@ -3243,9 +3003,24 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, AR->getLoop(), SCEV::FlagAnyWrap)) { const APInt &StartInt = StartC->getAPInt(); const APInt &StartRem = StartInt.urem(StepInt); - if (StartRem != 0) - LHS = getAddRecExpr(getConstant(StartInt - StartRem), Step, - AR->getLoop(), SCEV::FlagNW); + if (StartRem != 0) { + const SCEV *NewLHS = + getAddRecExpr(getConstant(StartInt - StartRem), Step, + AR->getLoop(), SCEV::FlagNW); + if (LHS != NewLHS) { + LHS = NewLHS; + + // Reset the ID to include the new LHS, and check if it is + // already cached. + ID.clear(); + ID.AddInteger(scUDivExpr); + ID.AddPointer(LHS); + ID.AddPointer(RHS); + IP = nullptr; + if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) + return S; + } + } } } // (A*B)/C --> A*(B/C) if safe and B/C can be folded. @@ -3310,11 +3085,9 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, } } - FoldingSetNodeID ID; - ID.AddInteger(scUDivExpr); - ID.AddPointer(LHS); - ID.AddPointer(RHS); - void *IP = nullptr; + // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs + // changes). Make sure we get a new one. + IP = nullptr; if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator), LHS, RHS); @@ -3505,9 +3278,8 @@ ScalarEvolution::getGEPExpr(GEPOperator *GEP, : SCEV::FlagAnyWrap; const SCEV *TotalOffset = getZero(IntIdxTy); - // The array size is unimportant. The first thing we do on CurTy is getting - // its element type. - Type *CurTy = ArrayType::get(GEP->getSourceElementType(), 0); + Type *CurTy = GEP->getType(); + bool FirstIter = true; for (const SCEV *IndexExpr : IndexExprs) { // Compute the (potentially symbolic) offset in bytes for this index. if (StructType *STy = dyn_cast<StructType>(CurTy)) { @@ -3523,7 +3295,14 @@ ScalarEvolution::getGEPExpr(GEPOperator *GEP, CurTy = STy->getTypeAtIndex(Index); } else { // Update CurTy to its element type. - CurTy = cast<SequentialType>(CurTy)->getElementType(); + if (FirstIter) { + assert(isa<PointerType>(CurTy) && + "The first index of a GEP indexes a pointer"); + CurTy = GEP->getSourceElementType(); + FirstIter = false; + } else { + CurTy = GetElementPtrInst::getTypeAtIndex(CurTy, (uint64_t)0); + } // For an array, add the element offset, explicitly scaled. const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy); // Getelementptr indices are signed. @@ -3538,10 +3317,13 @@ ScalarEvolution::getGEPExpr(GEPOperator *GEP, } // Add the total offset from all the GEP indices to the base. - return getAddExpr(BaseExpr, TotalOffset, Wrap); + auto *GEPExpr = getAddExpr(BaseExpr, TotalOffset, Wrap); + assert(BaseExpr->getType() == GEPExpr->getType() && + "GEP should not change type mid-flight."); + return GEPExpr; } -std::tuple<const SCEV *, FoldingSetNodeID, void *> +std::tuple<SCEV *, FoldingSetNodeID, void *> ScalarEvolution::findExistingSCEVInCache(int SCEVType, ArrayRef<const SCEV *> Ops) { FoldingSetNodeID ID; @@ -3549,7 +3331,7 @@ ScalarEvolution::findExistingSCEVInCache(int SCEVType, ID.AddInteger(SCEVType); for (unsigned i = 0, e = Ops.size(); i != e; ++i) ID.AddPointer(Ops[i]); - return std::tuple<const SCEV *, FoldingSetNodeID, void *>( + return std::tuple<SCEV *, FoldingSetNodeID, void *>( UniqueSCEVs.FindNodeOrInsertPos(ID, IP), std::move(ID), IP); } @@ -3727,6 +3509,12 @@ const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) { // We can bypass creating a target-independent // constant expression and then folding it back into a ConstantInt. // This is just a compile-time optimization. + if (isa<ScalableVectorType>(AllocTy)) { + Constant *NullPtr = Constant::getNullValue(AllocTy->getPointerTo()); + Constant *One = ConstantInt::get(IntTy, 1); + Constant *GEP = ConstantExpr::getGetElementPtr(AllocTy, NullPtr, One); + return getSCEV(ConstantExpr::getPtrToInt(GEP, IntTy)); + } return getConstant(IntTy, getDataLayout().getTypeAllocSize(AllocTy)); } @@ -3820,7 +3608,8 @@ bool ScalarEvolution::containsAddRecurrence(const SCEV *S) { if (I != HasRecMap.end()) return I->second; - bool FoundAddRec = SCEVExprContains(S, isa<SCEVAddRecExpr, const SCEV *>); + bool FoundAddRec = + SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); }); HasRecMap.insert({S, FoundAddRec}); return FoundAddRec; } @@ -4167,23 +3956,25 @@ const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) { if (!V->getType()->isPointerTy()) return V; - if (const SCEVCastExpr *Cast = dyn_cast<SCEVCastExpr>(V)) { - return getPointerBase(Cast->getOperand()); - } else if (const SCEVNAryExpr *NAry = dyn_cast<SCEVNAryExpr>(V)) { - const SCEV *PtrOp = nullptr; - for (const SCEV *NAryOp : NAry->operands()) { - if (NAryOp->getType()->isPointerTy()) { - // Cannot find the base of an expression with multiple pointer operands. - if (PtrOp) - return V; - PtrOp = NAryOp; + while (true) { + if (const SCEVCastExpr *Cast = dyn_cast<SCEVCastExpr>(V)) { + V = Cast->getOperand(); + } else if (const SCEVNAryExpr *NAry = dyn_cast<SCEVNAryExpr>(V)) { + const SCEV *PtrOp = nullptr; + for (const SCEV *NAryOp : NAry->operands()) { + if (NAryOp->getType()->isPointerTy()) { + // Cannot find the base of an expression with multiple pointer ops. + if (PtrOp) + return V; + PtrOp = NAryOp; + } } - } - if (!PtrOp) + if (!PtrOp) // All operands were non-pointer. + return V; + V = PtrOp; + } else // Not something we can look further into. return V; - return getPointerBase(PtrOp); } - return V; } /// Push users of the given Instruction onto the given Worklist. @@ -5740,7 +5531,7 @@ ScalarEvolution::getRangeRef(const SCEV *S, // For a SCEVUnknown, ask ValueTracking. KnownBits Known = computeKnownBits(U->getValue(), DL, 0, &AC, nullptr, &DT); if (Known.getBitWidth() != BitWidth) - Known = Known.zextOrTrunc(BitWidth, true); + Known = Known.zextOrTrunc(BitWidth); // If Known does not result in full-set, intersect with it. if (Known.getMinValue() != Known.getMaxValue() + 1) ConservativeResult = ConservativeResult.intersectWith( @@ -6032,7 +5823,7 @@ bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) { return false; // Only proceed if we can prove that I does not yield poison. - if (!programUndefinedIfFullPoison(I)) + if (!programUndefinedIfPoison(I)) return false; // At this point we know that if I is executed, then it does not wrap @@ -6112,7 +5903,7 @@ bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) { SmallVector<const Instruction *, 8> PoisonStack; // We start by assuming \c I, the post-inc add recurrence, is poison. Only - // things that are known to be fully poison under that assumption go on the + // things that are known to be poison under that assumption go on the // PoisonStack. Pushed.insert(I); PoisonStack.push_back(I); @@ -6122,7 +5913,7 @@ bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) { const Instruction *Poison = PoisonStack.pop_back_val(); for (auto *PoisonUser : Poison->users()) { - if (propagatesFullPoison(cast<Instruction>(PoisonUser))) { + if (propagatesPoison(cast<Instruction>(PoisonUser))) { if (Pushed.insert(cast<Instruction>(PoisonUser)).second) PoisonStack.push_back(cast<Instruction>(PoisonUser)); } else if (auto *BI = dyn_cast<BranchInst>(PoisonUser)) { @@ -6349,15 +6140,8 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { if (GetMinTrailingZeros(LHS) >= (CIVal.getBitWidth() - CIVal.countLeadingZeros())) { // Build a plain add SCEV. - const SCEV *S = getAddExpr(LHS, getSCEV(CI)); - // If the LHS of the add was an addrec and it has no-wrap flags, - // transfer the no-wrap flags, since an or won't introduce a wrap. - if (const SCEVAddRecExpr *NewAR = dyn_cast<SCEVAddRecExpr>(S)) { - const SCEVAddRecExpr *OldAR = cast<SCEVAddRecExpr>(LHS); - const_cast<SCEVAddRecExpr *>(NewAR)->setNoWrapFlags( - OldAR->getNoWrapFlags()); - } - return S; + return getAddExpr(LHS, getSCEV(CI), + (SCEV::NoWrapFlags)(SCEV::FlagNUW | SCEV::FlagNSW)); } } break; @@ -6413,15 +6197,19 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { if (SA->getValue().uge(BitWidth)) break; - // It is currently not resolved how to interpret NSW for left - // shift by BitWidth - 1, so we avoid applying flags in that - // case. Remove this check (or this comment) once the situation - // is resolved. See - // http://lists.llvm.org/pipermail/llvm-dev/2015-April/084195.html - // and http://reviews.llvm.org/D8890 . + // We can safely preserve the nuw flag in all cases. It's also safe to + // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation + // requires special handling. It can be preserved as long as we're not + // left shifting by bitwidth - 1. auto Flags = SCEV::FlagAnyWrap; - if (BO->Op && SA->getValue().ult(BitWidth - 1)) - Flags = getNoWrapFlagsFromUB(BO->Op); + if (BO->Op) { + auto MulFlags = getNoWrapFlagsFromUB(BO->Op); + if ((MulFlags & SCEV::FlagNSW) && + ((MulFlags & SCEV::FlagNUW) || SA->getValue().ult(BitWidth - 1))) + Flags = (SCEV::NoWrapFlags)(Flags | SCEV::FlagNSW); + if (MulFlags & SCEV::FlagNUW) + Flags = (SCEV::NoWrapFlags)(Flags | SCEV::FlagNUW); + } Constant *X = ConstantInt::get( getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue())); @@ -6515,6 +6303,20 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { return getSCEV(U->getOperand(0)); break; + case Instruction::SDiv: + // If both operands are non-negative, this is just an udiv. + if (isKnownNonNegative(getSCEV(U->getOperand(0))) && + isKnownNonNegative(getSCEV(U->getOperand(1)))) + return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1))); + break; + + case Instruction::SRem: + // If both operands are non-negative, this is just an urem. + if (isKnownNonNegative(getSCEV(U->getOperand(0))) && + isKnownNonNegative(getSCEV(U->getOperand(1)))) + return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1))); + break; + // It's tempting to handle inttoptr and ptrtoint as no-ops, however this can // lead to pointer expressions which cannot safely be expanded to GEPs, // because ScalarEvolution doesn't respect the GEP aliasing rules when @@ -6538,7 +6340,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { case Instruction::Call: case Instruction::Invoke: - if (Value *RV = CallSite(U).getReturnedArgOperand()) + if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) return getSCEV(RV); break; } @@ -6644,7 +6446,7 @@ const SCEV *ScalarEvolution::getExitCount(const Loop *L, BasicBlock *ExitingBlock, ExitCountKind Kind) { switch (Kind) { - case Exact: + case Exact: return getBackedgeTakenInfo(L).getExact(ExitingBlock, this); case ConstantMaximum: return getBackedgeTakenInfo(L).getMax(ExitingBlock, this); @@ -6661,7 +6463,7 @@ ScalarEvolution::getPredicatedBackedgeTakenCount(const Loop *L, const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L, ExitCountKind Kind) { switch (Kind) { - case Exact: + case Exact: return getBackedgeTakenInfo(L).getExact(L, this); case ConstantMaximum: return getBackedgeTakenInfo(L).getMax(this); @@ -6924,6 +6726,10 @@ void ScalarEvolution::forgetValue(Value *V) { } } +void ScalarEvolution::forgetLoopDispositions(const Loop *L) { + LoopDispositions.clear(); +} + /// Get the exact loop backedge taken count considering all loop exits. A /// computable result can only be returned for loops with all exiting blocks /// dominating the latch. howFarToZero assumes that the limit of each loop test @@ -8244,10 +8050,11 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) && isKnownPositive(BackedgeTakenCount) && PN->getNumIncomingValues() == 2) { + unsigned InLoopPred = LI->contains(PN->getIncomingBlock(0)) ? 0 : 1; - const SCEV *OnBackedge = getSCEV(PN->getIncomingValue(InLoopPred)); - if (IsAvailableOnEntry(LI, DT, OnBackedge, PN->getParent())) - return OnBackedge; + Value *BackedgeVal = PN->getIncomingValue(InLoopPred); + if (LI->isLoopInvariant(BackedgeVal)) + return getSCEV(BackedgeVal); } if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) { // Okay, we know how many times the containing loop executes. If @@ -9226,9 +9033,11 @@ bool ScalarEvolution::isKnownViaInduction(ICmpInst::Predicate Pred, !isAvailableAtLoopEntry(SplitRHS.first, MDL)) return false; - return isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first) && - isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second, - SplitRHS.second); + // It seems backedge guard check is faster than entry one so in some cases + // it can speed up whole estimation by short circuit + return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second, + SplitRHS.second) && + isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first); } bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred, @@ -11161,8 +10970,9 @@ static bool findArrayDimensionsRec(ScalarEvolution &SE, // Returns true when one of the SCEVs of Terms contains a SCEVUnknown parameter. static inline bool containsParameters(SmallVectorImpl<const SCEV *> &Terms) { for (const SCEV *T : Terms) - if (SCEVExprContains(T, isa<SCEVUnknown, const SCEV *>)) + if (SCEVExprContains(T, [](const SCEV *S) { return isa<SCEVUnknown>(S); })) return true; + return false; } @@ -11411,6 +11221,51 @@ void ScalarEvolution::delinearize(const SCEV *Expr, }); } +bool ScalarEvolution::getIndexExpressionsFromGEP( + const GetElementPtrInst *GEP, SmallVectorImpl<const SCEV *> &Subscripts, + SmallVectorImpl<int> &Sizes) { + assert(Subscripts.empty() && Sizes.empty() && + "Expected output lists to be empty on entry to this function."); + assert(GEP && "getIndexExpressionsFromGEP called with a null GEP"); + Type *Ty = GEP->getPointerOperandType(); + bool DroppedFirstDim = false; + for (unsigned i = 1; i < GEP->getNumOperands(); i++) { + const SCEV *Expr = getSCEV(GEP->getOperand(i)); + if (i == 1) { + if (auto *PtrTy = dyn_cast<PointerType>(Ty)) { + Ty = PtrTy->getElementType(); + } else if (auto *ArrayTy = dyn_cast<ArrayType>(Ty)) { + Ty = ArrayTy->getElementType(); + } else { + Subscripts.clear(); + Sizes.clear(); + return false; + } + if (auto *Const = dyn_cast<SCEVConstant>(Expr)) + if (Const->getValue()->isZero()) { + DroppedFirstDim = true; + continue; + } + Subscripts.push_back(Expr); + continue; + } + + auto *ArrayTy = dyn_cast<ArrayType>(Ty); + if (!ArrayTy) { + Subscripts.clear(); + Sizes.clear(); + return false; + } + + Subscripts.push_back(Expr); + if (!(DroppedFirstDim && i == 2)) + Sizes.push_back(ArrayTy->getNumElements()); + + Ty = ArrayTy->getElementType(); + } + return !Subscripts.empty(); +} + //===----------------------------------------------------------------------===// // SCEVCallbackVH Class Implementation //===----------------------------------------------------------------------===// |