author    | Dimitry Andric <dim@FreeBSD.org> | 2017-04-16 16:01:22 +0000
committer | Dimitry Andric <dim@FreeBSD.org> | 2017-04-16 16:01:22 +0000
commit    | 71d5a2540a98c81f5bcaeb48805e0e2881f530ef (patch)
tree      | 5343938942df402b49ec7300a1c25a2d4ccd5821 /lib/Analysis/ScalarEvolution.cpp
parent    | 31bbf64f3a4974a2d6c8b3b27ad2f519caf74057 (diff)
Diffstat (limited to 'lib/Analysis/ScalarEvolution.cpp')
-rw-r--r-- | lib/Analysis/ScalarEvolution.cpp | 649
1 file changed, 442 insertions, 207 deletions
diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index ed328f12c463..ca32cf3c7c34 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -127,16 +127,35 @@ static cl::opt<unsigned> MulOpsInlineThreshold( cl::desc("Threshold for inlining multiplication operands into a SCEV"), cl::init(1000)); +static cl::opt<unsigned> AddOpsInlineThreshold( + "scev-addops-inline-threshold", cl::Hidden, + cl::desc("Threshold for inlining addition operands into a SCEV"), + cl::init(500)); + static cl::opt<unsigned> MaxSCEVCompareDepth( "scalar-evolution-max-scev-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV complexity comparisons"), cl::init(32)); +static cl::opt<unsigned> MaxSCEVOperationsImplicationDepth( + "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden, + cl::desc("Maximum depth of recursive SCEV operations implication analysis"), + cl::init(2)); + static cl::opt<unsigned> MaxValueCompareDepth( "scalar-evolution-max-value-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive value complexity comparisons"), cl::init(2)); +static cl::opt<unsigned> + MaxAddExprDepth("scalar-evolution-max-addexpr-depth", cl::Hidden, + cl::desc("Maximum depth of recursive AddExpr"), + cl::init(32)); + +static cl::opt<unsigned> MaxConstantEvolvingDepth( + "scalar-evolution-max-constant-evolving-depth", cl::Hidden, + cl::desc("Maximum depth of recursive constant evolving"), cl::init(32)); + //===----------------------------------------------------------------------===// // SCEV class definitions //===----------------------------------------------------------------------===// @@ -145,11 +164,12 @@ static cl::opt<unsigned> MaxValueCompareDepth( // Implementation of the SCEV class. // -LLVM_DUMP_METHOD -void SCEV::dump() const { +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void SCEV::dump() const { print(dbgs()); dbgs() << '\n'; } +#endif void SCEV::print(raw_ostream &OS) const { switch (static_cast<SCEVTypes>(getSCEVType())) { @@ -2095,7 +2115,8 @@ StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, /// Get a canonical add expression, or something simpler if possible. const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, - SCEV::NoWrapFlags Flags) { + SCEV::NoWrapFlags Flags, + unsigned Depth) { assert(!(Flags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) && "only nuw or nsw allowed"); assert(!Ops.empty() && "Cannot get empty add!"); @@ -2134,6 +2155,10 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, if (Ops.size() == 1) return Ops[0]; } + // Limit the recursion depth. + if (Depth > MaxAddExprDepth) + return getOrCreateAddExpr(Ops, Flags); + // Okay, check to see if the same value occurs in the operand list more than // once. If so, merge them together into a multiply expression. Since we // sorted the list, these values are required to be adjacent. @@ -2205,7 +2230,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, } if (Ok) { // Evaluate the expression in the larger type. - const SCEV *Fold = getAddExpr(LargeOps, Flags); + const SCEV *Fold = getAddExpr(LargeOps, Flags, Depth + 1); // If it folds to something simple, use it. Otherwise, don't. 
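The core of the change is visible in the hunks above: getAddExpr gains a Depth parameter, every recursive simplification call passes Depth + 1, and once the budget set by -scalar-evolution-max-addexpr-depth is spent it falls back to getOrCreateAddExpr, which builds the node without attempting further folding. This is also why, later in the diff, two-operand convenience calls like getAddExpr(One, InnerMul) are rewritten as explicit SmallVectors: only the vector overload carries Depth. A minimal standalone sketch of the pattern follows — toy types, not LLVM's API:

```cpp
// Toy illustration of the depth cap threaded through getAddExpr: recursive
// folding carries a Depth counter and gives up past the budget, trading a
// less canonical result for bounded compile time and stack usage.
#include <cstdio>
#include <memory>

struct Expr {
  int Leaf = 0;                 // used when L and R are null
  std::unique_ptr<Expr> L, R;   // non-null => this node is L + R
};

constexpr unsigned MaxAddExprDepth = 32; // the patch's default budget

// Try to fold a tree of additions to a constant; return false on bail-out.
// A real simplifier (like getOrCreateAddExpr) would instead return the node
// unsimplified, so correctness is never at stake -- only canonical form.
bool foldAdd(const Expr &E, unsigned Depth, int &Out) {
  if (!E.L) { Out = E.Leaf; return true; }
  if (Depth > MaxAddExprDepth)
    return false;
  int A, B;
  if (!foldAdd(*E.L, Depth + 1, A) || !foldAdd(*E.R, Depth + 1, B))
    return false;
  Out = A + B;
  return true;
}

int main() {
  // Build a chain of nested adds deeper than the budget: 1 + (1 + (1 + ...)).
  auto Root = std::make_unique<Expr>();
  Expr *N = Root.get();
  for (int i = 0; i < 40; ++i) {
    N->L = std::make_unique<Expr>();
    N->L->Leaf = 1;
    N->R = std::make_unique<Expr>();
    N = N->R.get();
  }
  N->Leaf = 1;
  int Out;
  if (foldAdd(*Root, 0, Out))
    std::printf("folded to %d\n", Out);
  else
    std::printf("hit the depth cap; would return the expression unfolded\n");
}
```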
if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold)) return getTruncateExpr(Fold, DstType); @@ -2220,6 +2245,9 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, if (Idx < Ops.size()) { bool DeletedAdd = false; while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) { + if (Ops.size() > AddOpsInlineThreshold || + Add->getNumOperands() > AddOpsInlineThreshold) + break; // If we have an add, expand the add operands onto the end of the operands // list. Ops.erase(Ops.begin()+Idx); @@ -2231,7 +2259,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, // and they are not necessarily sorted. Recurse to resort and resimplify // any operands we just acquired. if (DeletedAdd) - return getAddExpr(Ops); + return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); } // Skip over the add expression until we get to a multiply. @@ -2266,13 +2294,14 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, Ops.push_back(getConstant(AccumulatedConstant)); for (auto &MulOp : MulOpLists) if (MulOp.first != 0) - Ops.push_back(getMulExpr(getConstant(MulOp.first), - getAddExpr(MulOp.second))); + Ops.push_back(getMulExpr( + getConstant(MulOp.first), + getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1))); if (Ops.empty()) return getZero(Ty); if (Ops.size() == 1) return Ops[0]; - return getAddExpr(Ops); + return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); } } @@ -2297,8 +2326,8 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end()); InnerMul = getMulExpr(MulOps); } - const SCEV *One = getOne(Ty); - const SCEV *AddOne = getAddExpr(One, InnerMul); + SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul}; + const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1); const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV); if (Ops.size() == 2) return OuterMul; if (AddOp < Idx) { @@ -2309,7 +2338,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, Ops.erase(Ops.begin()+AddOp-1); } Ops.push_back(OuterMul); - return getAddExpr(Ops); + return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); } // Check this multiply against other multiplies being added together. @@ -2337,13 +2366,15 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, MulOps.append(OtherMul->op_begin()+OMulOp+1, OtherMul->op_end()); InnerMul2 = getMulExpr(MulOps); } - const SCEV *InnerMulSum = getAddExpr(InnerMul1,InnerMul2); + SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2}; + const SCEV *InnerMulSum = + getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1); const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum); if (Ops.size() == 2) return OuterMul; Ops.erase(Ops.begin()+Idx); Ops.erase(Ops.begin()+OtherMulIdx-1); Ops.push_back(OuterMul); - return getAddExpr(Ops); + return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); } } } @@ -2379,7 +2410,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, // This follows from the fact that the no-wrap flags on the outer add // expression are applicable on the 0th iteration, when the add recurrence // will be equal to its start value. - AddRecOps[0] = getAddExpr(LIOps, Flags); + AddRecOps[0] = getAddExpr(LIOps, Flags, Depth + 1); // Build the new addrec. Propagate the NUW and NSW flags if both the // outer add and the inner addrec are guaranteed to have no overflow. 
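The first hunk above is the companion size guard: even below the depth cap, getAddExpr stops splicing a nested add's operands into its own list once either list crosses scev-addops-inline-threshold (500 by default, versus 1000 for multiplies). A standalone sketch of that splice-with-guard, using toy types rather than SCEV:

```cpp
// Flatten nested sums into a single operand list, but -- like the patch --
// refuse to inline a child when either list is already past the threshold.
// The expression stays correct; the child is just kept as one opaque operand.
#include <cstddef>
#include <cstdio>
#include <variant>
#include <vector>

struct Sum;
using Op = std::variant<int, const Sum *>; // leaf or nested sum
struct Sum { std::vector<Op> Ops; };

constexpr std::size_t Threshold = 500;     // patch default for adds

std::vector<Op> flatten(const Sum &S) {
  std::vector<Op> Ops = S.Ops;
  for (std::size_t i = 0; i < Ops.size();) {
    auto *Child = std::get_if<const Sum *>(&Ops[i]);
    if (!Child || Ops.size() > Threshold || (*Child)->Ops.size() > Threshold) {
      ++i;                                 // keep as-is (unexpanded operand)
      continue;
    }
    const Sum &C = **Child;                // splice the child's operands in
    Ops.erase(Ops.begin() + i);
    Ops.insert(Ops.begin() + i, C.Ops.begin(), C.Ops.end());
  }
  return Ops;
}

int main() {
  Sum Inner{{1, 2, 3}};
  Sum Outer{{7, &Inner, 4}};
  std::printf("%zu operands after flattening\n", flatten(Outer).size()); // 5
}
```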
@@ -2396,7 +2427,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, Ops[i] = NewRec; break; } - return getAddExpr(Ops); + return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); } // Okay, if there weren't any loop invariants to be folded, check to see if @@ -2420,14 +2451,15 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, OtherAddRec->op_end()); break; } - AddRecOps[i] = getAddExpr(AddRecOps[i], - OtherAddRec->getOperand(i)); + SmallVector<const SCEV *, 2> TwoOps = { + AddRecOps[i], OtherAddRec->getOperand(i)}; + AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1); } Ops.erase(Ops.begin() + OtherIdx); --OtherIdx; } // Step size has changed, so we cannot guarantee no self-wraparound. Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap); - return getAddExpr(Ops); + return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); } // Otherwise couldn't fold anything into this recurrence. Move onto the @@ -2436,18 +2468,24 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, // Okay, it looks like we really DO need an add expr. Check to see if we // already have one, otherwise create a new one. + return getOrCreateAddExpr(Ops, Flags); +} + +const SCEV * +ScalarEvolution::getOrCreateAddExpr(SmallVectorImpl<const SCEV *> &Ops, + SCEV::NoWrapFlags Flags) { FoldingSetNodeID ID; ID.AddInteger(scAddExpr); for (unsigned i = 0, e = Ops.size(); i != e; ++i) ID.AddPointer(Ops[i]); void *IP = nullptr; SCEVAddExpr *S = - static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); + static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); if (!S) { const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size()); std::uninitialized_copy(Ops.begin(), Ops.end(), O); - S = new (SCEVAllocator) SCEVAddExpr(ID.Intern(SCEVAllocator), - O, Ops.size()); + S = new (SCEVAllocator) + SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size()); UniqueSCEVs.InsertNode(S, IP); } S->setNoWrapFlags(Flags); @@ -2889,7 +2927,7 @@ const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS, // end of this file for inspiration. const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS); - if (!Mul) + if (!Mul || !Mul->hasNoUnsignedWrap()) return getUDivExpr(LHS, RHS); if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) { @@ -3385,6 +3423,10 @@ Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const { return getDataLayout().getIntPtrType(Ty); } +Type *ScalarEvolution::getWiderType(Type *T1, Type *T2) const { + return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2; +} + const SCEV *ScalarEvolution::getCouldNotCompute() { return CouldNotCompute.get(); } @@ -4409,8 +4451,7 @@ const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) { return getGEPExpr(GEP, IndexExprs); } -uint32_t -ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { +uint32_t ScalarEvolution::GetMinTrailingZerosImpl(const SCEV *S) { if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) return C->getAPInt().countTrailingZeros(); @@ -4420,14 +4461,16 @@ ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { if (const SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) { uint32_t OpRes = GetMinTrailingZeros(E->getOperand()); - return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ? - getTypeSizeInBits(E->getType()) : OpRes; + return OpRes == getTypeSizeInBits(E->getOperand()->getType()) + ? 
getTypeSizeInBits(E->getType()) + : OpRes; } if (const SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) { uint32_t OpRes = GetMinTrailingZeros(E->getOperand()); - return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ? - getTypeSizeInBits(E->getType()) : OpRes; + return OpRes == getTypeSizeInBits(E->getOperand()->getType()) + ? getTypeSizeInBits(E->getType()) + : OpRes; } if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) { @@ -4444,8 +4487,8 @@ ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { uint32_t BitWidth = getTypeSizeInBits(M->getType()); for (unsigned i = 1, e = M->getNumOperands(); SumOpRes != BitWidth && i != e; ++i) - SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)), - BitWidth); + SumOpRes = + std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)), BitWidth); return SumOpRes; } @@ -4486,6 +4529,17 @@ ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { return 0; } +uint32_t ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { + auto I = MinTrailingZerosCache.find(S); + if (I != MinTrailingZerosCache.end()) + return I->second; + + uint32_t Result = GetMinTrailingZerosImpl(S); + auto InsertPair = MinTrailingZerosCache.insert({S, Result}); + assert(InsertPair.second && "Should insert a new key"); + return InsertPair.first->second; +} + /// Helper method to assign a range to V from metadata present in the IR. static Optional<ConstantRange> GetRangeFromMetadata(Value *V) { if (Instruction *I = dyn_cast<Instruction>(V)) @@ -4668,6 +4722,77 @@ ScalarEvolution::getRange(const SCEV *S, return setRange(S, SignHint, ConservativeResult); } +// Given a StartRange, Step and MaxBECount for an expression compute a range of +// values that the expression can take. Initially, the expression has a value +// from StartRange and then is changed by Step up to MaxBECount times. Signed +// argument defines if we treat Step as signed or unsigned. +static ConstantRange getRangeForAffineARHelper(APInt Step, + ConstantRange StartRange, + APInt MaxBECount, + unsigned BitWidth, bool Signed) { + // If either Step or MaxBECount is 0, then the expression won't change, and we + // just need to return the initial range. + if (Step == 0 || MaxBECount == 0) + return StartRange; + + // If we don't know anything about the initial value (i.e. StartRange is + // FullRange), then we don't know anything about the final range either. + // Return FullRange. + if (StartRange.isFullSet()) + return ConstantRange(BitWidth, /* isFullSet = */ true); + + // If Step is signed and negative, then we use its absolute value, but we also + // note that we're moving in the opposite direction. + bool Descending = Signed && Step.isNegative(); + + if (Signed) + // This is correct even for INT_SMIN. Let's look at i8 to illustrate this: + // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128. + // This equations hold true due to the well-defined wrap-around behavior of + // APInt. + Step = Step.abs(); + + // Check if Offset is more than full span of BitWidth. If it is, the + // expression is guaranteed to overflow. + if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount)) + return ConstantRange(BitWidth, /* isFullSet = */ true); + + // Offset is by how much the expression can change. Checks above guarantee no + // overflow here. + APInt Offset = Step * MaxBECount; + + // Minimum value of the final range will match the minimal value of StartRange + // if the expression is increasing and will be decreased by Offset otherwise. 
+ // Maximum value of the final range will match the maximal value of StartRange + // if the expression is decreasing and will be increased by Offset otherwise. + APInt StartLower = StartRange.getLower(); + APInt StartUpper = StartRange.getUpper() - 1; + APInt MovedBoundary = + Descending ? (StartLower - Offset) : (StartUpper + Offset); + + // It's possible that the new minimum/maximum value will fall into the initial + // range (due to wrap around). This means that the expression can take any + // value in this bitwidth, and we have to return full range. + if (StartRange.contains(MovedBoundary)) + return ConstantRange(BitWidth, /* isFullSet = */ true); + + APInt NewLower, NewUpper; + if (Descending) { + NewLower = MovedBoundary; + NewUpper = StartUpper; + } else { + NewLower = StartLower; + NewUpper = MovedBoundary; + } + + // If we end up with full range, return a proper full range. + if (NewLower == NewUpper + 1) + return ConstantRange(BitWidth, /* isFullSet = */ true); + + // No overflow detected, return [StartLower, StartUpper + Offset + 1) range. + return ConstantRange(NewLower, NewUpper + 1); +} + ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start, const SCEV *Step, const SCEV *MaxBECount, @@ -4676,60 +4801,30 @@ ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start, getTypeSizeInBits(MaxBECount->getType()) <= BitWidth && "Precondition!"); - ConstantRange Result(BitWidth, /* isFullSet = */ true); - - // Check for overflow. This must be done with ConstantRange arithmetic - // because we could be called from within the ScalarEvolution overflow - // checking code. - MaxBECount = getNoopOrZeroExtend(MaxBECount, Start->getType()); ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount); - ConstantRange ZExtMaxBECountRange = MaxBECountRange.zextOrTrunc(BitWidth * 2); + APInt MaxBECountValue = MaxBECountRange.getUnsignedMax(); + // First, consider step signed. + ConstantRange StartSRange = getSignedRange(Start); ConstantRange StepSRange = getSignedRange(Step); - ConstantRange SExtStepSRange = StepSRange.sextOrTrunc(BitWidth * 2); - - ConstantRange StartURange = getUnsignedRange(Start); - ConstantRange EndURange = - StartURange.add(MaxBECountRange.multiply(StepSRange)); - - // Check for unsigned overflow. - ConstantRange ZExtStartURange = StartURange.zextOrTrunc(BitWidth * 2); - ConstantRange ZExtEndURange = EndURange.zextOrTrunc(BitWidth * 2); - if (ZExtStartURange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) == - ZExtEndURange) { - APInt Min = APIntOps::umin(StartURange.getUnsignedMin(), - EndURange.getUnsignedMin()); - APInt Max = APIntOps::umax(StartURange.getUnsignedMax(), - EndURange.getUnsignedMax()); - bool IsFullRange = Min.isMinValue() && Max.isMaxValue(); - if (!IsFullRange) - Result = - Result.intersectWith(ConstantRange(Min, Max + 1)); - } - ConstantRange StartSRange = getSignedRange(Start); - ConstantRange EndSRange = - StartSRange.add(MaxBECountRange.multiply(StepSRange)); - - // Check for signed overflow. This must be done with ConstantRange - // arithmetic because we could be called from within the ScalarEvolution - // overflow checking code. 
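getRangeForAffineARHelper, added above, replaces the old double-width overflow checks: starting from StartRange = [Lo, Hi) and taking at most MaxBECount steps of size Step, the reachable values lie in [Lo, (Hi - 1) + Step*MaxBECount + 1), unless the moved boundary wraps back into the start range, in which case only the full set is sound. A brute-force check of the no-wrap case, in plain unsigned 8-bit arithmetic rather than ConstantRange/APInt:

```cpp
// Brute-force check of the helper's no-wrap, unsigned (Signed == false) case:
// every value Start + k*Step with Start in [Lo, Hi) and 0 <= k <= MaxBECount
// lands in [Lo, (Hi - 1) + Step*MaxBECount + 1).
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
  const uint8_t Lo = 10, Hi = 20;  // StartRange = [10, 20)
  const uint8_t Step = 3;          // positive step
  const unsigned MaxBECount = 5;

  // The helper's overflow guard: if 255 / Step < MaxBECount the value may
  // wrap, and only the full range is a sound answer. Here 255 / 3 = 85 >= 5.
  assert(255u / Step >= MaxBECount);

  const uint8_t Offset = Step * MaxBECount;                // 15
  const uint8_t NewLo = Lo, NewHi = (Hi - 1) + Offset + 1; // [10, 35)
  // (The real helper additionally rejects the case where the moved boundary
  // wraps back into StartRange, returning the full set instead.)

  for (unsigned S = Lo; S < Hi; ++S)
    for (unsigned k = 0; k <= MaxBECount; ++k) {
      uint8_t V = (uint8_t)(S + k * Step);
      assert(NewLo <= V && V < NewHi);
    }
  std::printf("all reachable values lie in [%u, %u)\n",
              (unsigned)NewLo, (unsigned)NewHi);
}
```

getRangeForAffineAR then unions the estimates for the signed minimum and maximum step, computes one more for the unsigned maximum step, and intersects the signed and unsigned results, as the hunk below shows.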
- ConstantRange SExtStartSRange = StartSRange.sextOrTrunc(BitWidth * 2); - ConstantRange SExtEndSRange = EndSRange.sextOrTrunc(BitWidth * 2); - if (SExtStartSRange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) == - SExtEndSRange) { - APInt Min = - APIntOps::smin(StartSRange.getSignedMin(), EndSRange.getSignedMin()); - APInt Max = - APIntOps::smax(StartSRange.getSignedMax(), EndSRange.getSignedMax()); - bool IsFullRange = Min.isMinSignedValue() && Max.isMaxSignedValue(); - if (!IsFullRange) - Result = - Result.intersectWith(ConstantRange(Min, Max + 1)); - } + // If Step can be both positive and negative, we need to find ranges for the + // maximum absolute step values in both directions and union them. + ConstantRange SR = + getRangeForAffineARHelper(StepSRange.getSignedMin(), StartSRange, + MaxBECountValue, BitWidth, /* Signed = */ true); + SR = SR.unionWith(getRangeForAffineARHelper(StepSRange.getSignedMax(), + StartSRange, MaxBECountValue, + BitWidth, /* Signed = */ true)); - return Result; + // Next, consider step unsigned. + ConstantRange UR = getRangeForAffineARHelper( + getUnsignedRange(Step).getUnsignedMax(), getUnsignedRange(Start), + MaxBECountValue, BitWidth, /* Signed = */ false); + + // Finally, intersect signed and unsigned ranges. + return SR.intersectWith(UR); } ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start, @@ -5148,12 +5243,27 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { APInt EffectiveMask = APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ); if ((LZ != 0 || TZ != 0) && !((~A & ~KnownZero) & EffectiveMask)) { - const SCEV *MulCount = getConstant(ConstantInt::get( - getContext(), APInt::getOneBitSet(BitWidth, TZ))); + const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ)); + const SCEV *LHS = getSCEV(BO->LHS); + const SCEV *ShiftedLHS = nullptr; + if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) { + if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) { + // For an expression like (x * 8) & 8, simplify the multiply. + unsigned MulZeros = OpC->getAPInt().countTrailingZeros(); + unsigned GCD = std::min(MulZeros, TZ); + APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD); + SmallVector<const SCEV*, 4> MulOps; + MulOps.push_back(getConstant(OpC->getAPInt().lshr(GCD))); + MulOps.append(LHSMul->op_begin() + 1, LHSMul->op_end()); + auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags()); + ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt)); + } + } + if (!ShiftedLHS) + ShiftedLHS = getUDivExpr(LHS, MulCount); return getMulExpr( getZeroExtendExpr( - getTruncateExpr( - getUDivExactExpr(getSCEV(BO->LHS), MulCount), + getTruncateExpr(ShiftedLHS, IntegerType::get(getContext(), BitWidth - LZ - TZ)), BO->LHS->getType()), MulCount); @@ -5211,7 +5321,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { // If C is a low-bits mask, the zero extend is serving to // mask off the high bits. Complement the operand and // re-apply the zext. - if (APIntOps::isMask(Z0TySize, CI->getValue())) + if (CI->getValue().isMask(Z0TySize)) return getZeroExtendExpr(getNotSCEV(Z0), UTy); // If C is a single bit, it may be in the sign-bit position @@ -5255,28 +5365,55 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { break; case Instruction::AShr: - // For a two-shift sext-inreg, use sext(trunc(x)) as the SCEV expression. 
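The And-mask hunk above also teaches createSCEV to cancel a constant multiply against the trailing zeros of the mask, instead of emitting a udiv of the whole product. For the comment's own example, (x * 8) & 8, the simplification amounts to (x & 1) * 8. A quick exhaustive check over i8 (plain C++, not SCEV):

```cpp
// Exhaustive i8 check: (x * 8) & 8 == (x & 1) * 8, i.e. the multiply can be
// cancelled against the mask's trailing zeros rather than dividing the whole
// product -- the symbolic analogue of the new ShiftedLHS path.
#include <cassert>
#include <cstdint>

int main() {
  for (unsigned v = 0; v < 256; ++v) {
    uint8_t x = (uint8_t)v;
    assert((uint8_t)((x * 8) & 8) == (uint8_t)((x & 1) * 8));
  }
  return 0;
}
```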
- if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) - if (Operator *L = dyn_cast<Operator>(BO->LHS)) - if (L->getOpcode() == Instruction::Shl && - L->getOperand(1) == BO->RHS) { - uint64_t BitWidth = getTypeSizeInBits(BO->LHS->getType()); - - // If the shift count is not less than the bitwidth, the result of - // the shift is undefined. Don't try to analyze it, because the - // resolution chosen here may differ from the resolution chosen in - // other parts of the compiler. - if (CI->getValue().uge(BitWidth)) - break; + // AShr X, C, where C is a constant. + ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS); + if (!CI) + break; + + Type *OuterTy = BO->LHS->getType(); + uint64_t BitWidth = getTypeSizeInBits(OuterTy); + // If the shift count is not less than the bitwidth, the result of + // the shift is undefined. Don't try to analyze it, because the + // resolution chosen here may differ from the resolution chosen in + // other parts of the compiler. + if (CI->getValue().uge(BitWidth)) + break; - uint64_t Amt = BitWidth - CI->getZExtValue(); - if (Amt == BitWidth) - return getSCEV(L->getOperand(0)); // shift by zero --> noop + if (CI->isNullValue()) + return getSCEV(BO->LHS); // shift by zero --> noop + + uint64_t AShrAmt = CI->getZExtValue(); + Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt); + + Operator *L = dyn_cast<Operator>(BO->LHS); + if (L && L->getOpcode() == Instruction::Shl) { + // X = Shl A, n + // Y = AShr X, m + // Both n and m are constant. + + const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0)); + if (L->getOperand(1) == BO->RHS) + // For a two-shift sext-inreg, i.e. n = m, + // use sext(trunc(x)) as the SCEV expression. + return getSignExtendExpr( + getTruncateExpr(ShlOp0SCEV, TruncTy), OuterTy); + + ConstantInt *ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1)); + if (ShlAmtCI && ShlAmtCI->getValue().ult(BitWidth)) { + uint64_t ShlAmt = ShlAmtCI->getZExtValue(); + if (ShlAmt > AShrAmt) { + // When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV + // expression. We already checked that ShlAmt < BitWidth, so + // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as + // ShlAmt - AShrAmt < Amt. 
+ APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt, + ShlAmt - AShrAmt); return getSignExtendExpr( - getTruncateExpr(getSCEV(L->getOperand(0)), - IntegerType::get(getContext(), Amt)), - BO->LHS->getType()); + getMulExpr(getTruncateExpr(ShlOp0SCEV, TruncTy), + getConstant(Mul)), OuterTy); } + } + } break; } } @@ -5348,7 +5485,7 @@ static unsigned getConstantTripCount(const SCEVConstant *ExitCount) { return ((unsigned)ExitConst->getZExtValue()) + 1; } -unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L) { +unsigned ScalarEvolution::getSmallConstantTripCount(const Loop *L) { if (BasicBlock *ExitingBB = L->getExitingBlock()) return getSmallConstantTripCount(L, ExitingBB); @@ -5356,7 +5493,7 @@ unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L) { return 0; } -unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L, +unsigned ScalarEvolution::getSmallConstantTripCount(const Loop *L, BasicBlock *ExitingBlock) { assert(ExitingBlock && "Must pass a non-null exiting block!"); assert(L->isLoopExiting(ExitingBlock) && @@ -5366,13 +5503,13 @@ unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L, return getConstantTripCount(ExitCount); } -unsigned ScalarEvolution::getSmallConstantMaxTripCount(Loop *L) { +unsigned ScalarEvolution::getSmallConstantMaxTripCount(const Loop *L) { const auto *MaxExitCount = dyn_cast<SCEVConstant>(getMaxBackedgeTakenCount(L)); return getConstantTripCount(MaxExitCount); } -unsigned ScalarEvolution::getSmallConstantTripMultiple(Loop *L) { +unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) { if (BasicBlock *ExitingBB = L->getExitingBlock()) return getSmallConstantTripMultiple(L, ExitingBB); @@ -5393,7 +5530,7 @@ unsigned ScalarEvolution::getSmallConstantTripMultiple(Loop *L) { /// As explained in the comments for getSmallConstantTripCount, this assumes /// that control exits the loop via ExitingBlock. unsigned -ScalarEvolution::getSmallConstantTripMultiple(Loop *L, +ScalarEvolution::getSmallConstantTripMultiple(const Loop *L, BasicBlock *ExitingBlock) { assert(ExitingBlock && "Must pass a non-null exiting block!"); assert(L->isLoopExiting(ExitingBlock) && @@ -5403,17 +5540,16 @@ ScalarEvolution::getSmallConstantTripMultiple(Loop *L, return 1; // Get the trip count from the BE count by adding 1. - const SCEV *TCMul = getAddExpr(ExitCount, getOne(ExitCount->getType())); - // FIXME: SCEV distributes multiplication as V1*C1 + V2*C1. We could attempt - // to factor simple cases. - if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(TCMul)) - TCMul = Mul->getOperand(0); - - const SCEVConstant *MulC = dyn_cast<SCEVConstant>(TCMul); - if (!MulC) - return 1; + const SCEV *TCExpr = getAddExpr(ExitCount, getOne(ExitCount->getType())); - ConstantInt *Result = MulC->getValue(); + const SCEVConstant *TC = dyn_cast<SCEVConstant>(TCExpr); + if (!TC) + // Attempt to factor more general cases. Returns the greatest power of + // two divisor. If overflow happens, the trip count expression is still + // divisible by the greatest power of 2 divisor returned. + return 1U << std::min((uint32_t)31, GetMinTrailingZeros(TCExpr)); + + ConstantInt *Result = TC->getValue(); // Guard against huge trip counts (this requires checking // for zero to handle the case where the trip count == -1 and the @@ -5428,7 +5564,8 @@ ScalarEvolution::getSmallConstantTripMultiple(Loop *L, /// Get the expression for the number of loop iterations for which this loop is /// guaranteed not to exit via ExitingBlock. Otherwise return /// SCEVCouldNotCompute. 
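The AShr rewrite above generalizes the old two-shift sext-inreg match: for X = shl A, n followed by ashr X, m, with constant n >= m both below the bit width, the result is sext(trunc(A) * 2^(n-m)) computed in the narrower type of width W - m. An exhaustive i8 check with n = 3, m = 1, emulating the i7 intermediate with masks (this assumes the host compiler implements >> on negative signed values as an arithmetic shift, as GCC and Clang do):

```cpp
// Check: ashr(shl(A, 3), 1) on i8 equals sext(trunc(A to i7) * 4, i8).
#include <cassert>
#include <cstdint>

int main() {
  const unsigned n = 3, m = 1; // shl and ashr amounts, n > m
  for (unsigned v = 0; v < 256; ++v) {
    int8_t A = (int8_t)v;
    // Y = ashr(shl(A, n), m), computed directly in i8:
    int8_t Y = (int8_t)((int8_t)(A << n) >> m);
    // sext(mul(trunc(A to i7), 2^(n-m)), i8), emulated with masks:
    uint8_t W = (uint8_t)((unsigned)A << (n - m)) & 0x7F; // 7-bit product A * 4
    if (W & 0x40)
      W |= 0x80;                                          // sign-extend i7 -> i8
    assert(Y == (int8_t)W);
  }
  return 0;
}
```

The trip-multiple hunk just above is the other user-visible win in this region: when the trip count is not a constant, getSmallConstantTripMultiple now returns the greatest power-of-two divisor derived from GetMinTrailingZeros of the trip-count expression, rather than giving up with 1.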
-const SCEV *ScalarEvolution::getExitCount(Loop *L, BasicBlock *ExitingBlock) { +const SCEV *ScalarEvolution::getExitCount(const Loop *L, + BasicBlock *ExitingBlock) { return getBackedgeTakenInfo(L).getExact(ExitingBlock, this); } @@ -6408,7 +6545,10 @@ static bool canConstantEvolve(Instruction *I, const Loop *L) { /// recursing through each instruction operand until reaching a loop header phi. static PHINode * getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, - DenseMap<Instruction *, PHINode *> &PHIMap) { + DenseMap<Instruction *, PHINode *> &PHIMap, + unsigned Depth) { + if (Depth > MaxConstantEvolvingDepth) + return nullptr; // Otherwise, we can evaluate this instruction if all of its operands are // constant or derived from a PHI node themselves. @@ -6428,7 +6568,7 @@ getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, if (!P) { // Recurse and memoize the results, whether a phi is found or not. // This recursive call invalidates pointers into PHIMap. - P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap); + P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1); PHIMap[OpInst] = P; } if (!P) @@ -6455,7 +6595,7 @@ static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) { // Record non-constant instructions contained by the loop. DenseMap<Instruction *, PHINode *> PHIMap; - return getConstantEvolvingPHIOperands(I, L, PHIMap); + return getConstantEvolvingPHIOperands(I, L, PHIMap, 0); } /// EvaluateExpression - Given an expression that passes the @@ -7014,10 +7154,10 @@ const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) { /// A and B isn't important. /// /// If the equation does not have a solution, SCEVCouldNotCompute is returned. -static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const APInt &B, +static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, ScalarEvolution &SE) { uint32_t BW = A.getBitWidth(); - assert(BW == B.getBitWidth() && "Bit widths must be the same."); + assert(BW == SE.getTypeSizeInBits(B->getType())); assert(A != 0 && "A must be non-zero."); // 1. D = gcd(A, N) @@ -7031,7 +7171,7 @@ static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const APInt &B, // // B is divisible by D if and only if the multiplicity of prime factor 2 for B // is not less than multiplicity of this prime factor for D. - if (B.countTrailingZeros() < Mult2) + if (SE.GetMinTrailingZeros(B) < Mult2) return SE.getCouldNotCompute(); // 3. 
Compute I: the multiplicative inverse of (A / D) in arithmetic @@ -7049,9 +7189,8 @@ static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const APInt &B, // I * (B / D) mod (N / D) // To simplify the computation, we factor out the divide by D: // (I * B mod N) / D - APInt Result = (I * B).lshr(Mult2); - - return SE.getConstant(Result); + const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2)); + return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D); } /// Find the roots of the quadratic equation for the given quadratic chrec @@ -7082,7 +7221,7 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { // Convert from chrec coefficients to polynomial coefficients AX^2+BX+C // The B coefficient is M-N/2 APInt B(M); - B -= sdiv(N,Two); + B -= N.sdiv(Two); // The A coefficient is N/2 APInt A(N.sdiv(Two)); @@ -7233,62 +7372,6 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit, return ExitLimit(Distance, getConstant(MaxBECount), false, Predicates); } - // As a special case, handle the instance where Step is a positive power of - // two. In this case, determining whether Step divides Distance evenly can be - // done by counting and comparing the number of trailing zeros of Step and - // Distance. - if (!CountDown) { - const APInt &StepV = StepC->getAPInt(); - // StepV.isPowerOf2() returns true if StepV is an positive power of two. It - // also returns true if StepV is maximally negative (eg, INT_MIN), but that - // case is not handled as this code is guarded by !CountDown. - if (StepV.isPowerOf2() && - GetMinTrailingZeros(Distance) >= StepV.countTrailingZeros()) { - // Here we've constrained the equation to be of the form - // - // 2^(N + k) * Distance' = (StepV == 2^N) * X (mod 2^W) ... (0) - // - // where we're operating on a W bit wide integer domain and k is - // non-negative. The smallest unsigned solution for X is the trip count. - // - // (0) is equivalent to: - // - // 2^(N + k) * Distance' - 2^N * X = L * 2^W - // <=> 2^N(2^k * Distance' - X) = L * 2^(W - N) * 2^N - // <=> 2^k * Distance' - X = L * 2^(W - N) - // <=> 2^k * Distance' = L * 2^(W - N) + X ... (1) - // - // The smallest X satisfying (1) is unsigned remainder of dividing the LHS - // by 2^(W - N). - // - // <=> X = 2^k * Distance' URem 2^(W - N) ... (2) - // - // E.g. say we're solving - // - // 2 * Val = 2 * X (in i8) ... (3) - // - // then from (2), we get X = Val URem i8 128 (k = 0 in this case). - // - // Note: It is tempting to solve (3) by setting X = Val, but Val is not - // necessarily the smallest unsigned value of X that satisfies (3). - // E.g. if Val is i8 -127 then the smallest value of X that satisfies (3) - // is i8 1, not i8 -127 - - const auto *ModuloResult = getUDivExactExpr(Distance, Step); - - // Since SCEV does not have a URem node, we construct one using a truncate - // and a zero extend. - - unsigned NarrowWidth = StepV.getBitWidth() - StepV.countTrailingZeros(); - auto *NarrowTy = IntegerType::get(getContext(), NarrowWidth); - auto *WideTy = Distance->getType(); - - const SCEV *Limit = - getZeroExtendExpr(getTruncateExpr(ModuloResult, NarrowTy), WideTy); - return ExitLimit(Limit, Limit, false, Predicates); - } - } - // If the condition controls loop exit (the loop exits only if the expression // is true) and the addition is no-wrap we can use unsigned divide to // compute the backedge count. 
In this case, the step may not divide the @@ -7301,13 +7384,10 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit, return ExitLimit(Exact, Exact, false, Predicates); } - // Then, try to solve the above equation provided that Start is constant. - if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start)) { - const SCEV *E = SolveLinEquationWithOverflow( - StepC->getValue()->getValue(), -StartC->getValue()->getValue(), *this); - return ExitLimit(E, E, false, Predicates); - } - return getCouldNotCompute(); + // Solve the general equation. + const SCEV *E = SolveLinEquationWithOverflow( + StepC->getAPInt(), getNegativeSCEV(Start), *this); + return ExitLimit(E, E, false, Predicates); } ScalarEvolution::ExitLimit @@ -8488,19 +8568,161 @@ static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, llvm_unreachable("covered switch fell through?!"); } +bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS, + const SCEV *FoundLHS, + const SCEV *FoundRHS, + unsigned Depth) { + assert(getTypeSizeInBits(LHS->getType()) == + getTypeSizeInBits(RHS->getType()) && + "LHS and RHS have different sizes?"); + assert(getTypeSizeInBits(FoundLHS->getType()) == + getTypeSizeInBits(FoundRHS->getType()) && + "FoundLHS and FoundRHS have different sizes?"); + // We want to avoid hurting the compile time with analysis of too big trees. + if (Depth > MaxSCEVOperationsImplicationDepth) + return false; + // We only want to work with ICMP_SGT comparison so far. + // TODO: Extend to ICMP_UGT? + if (Pred == ICmpInst::ICMP_SLT) { + Pred = ICmpInst::ICMP_SGT; + std::swap(LHS, RHS); + std::swap(FoundLHS, FoundRHS); + } + if (Pred != ICmpInst::ICMP_SGT) + return false; + + auto GetOpFromSExt = [&](const SCEV *S) { + if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S)) + return Ext->getOperand(); + // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off + // the constant in some cases. + return S; + }; + + // Acquire values from extensions. + auto *OrigFoundLHS = FoundLHS; + LHS = GetOpFromSExt(LHS); + FoundLHS = GetOpFromSExt(FoundLHS); + + // Is the SGT predicate can be proved trivially or using the found context. + auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) { + return isKnownViaSimpleReasoning(ICmpInst::ICMP_SGT, S1, S2) || + isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS, + FoundRHS, Depth + 1); + }; + + if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) { + // We want to avoid creation of any new non-constant SCEV. Since we are + // going to compare the operands to RHS, we should be certain that we don't + // need any size extensions for this. So let's decline all cases when the + // sizes of types of LHS and RHS do not match. + // TODO: Maybe try to get RHS from sext to catch more cases? + if (getTypeSizeInBits(LHS->getType()) != getTypeSizeInBits(RHS->getType())) + return false; + + // Should not overflow. + if (!LHSAddExpr->hasNoSignedWrap()) + return false; + + auto *LL = LHSAddExpr->getOperand(0); + auto *LR = LHSAddExpr->getOperand(1); + auto *MinusOne = getNegativeSCEV(getOne(RHS->getType())); + + // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context. + auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) { + return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS); + }; + // Try to prove the following rule: + // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS). + // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS). 
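Stepping back to the howFarToZero change at the top of this hunk: with B now an arbitrary SCEV, the old constant-Start-only path and the deleted power-of-two special case collapse into a single call to SolveLinEquationWithOverflow, whose divisibility test becomes GetMinTrailingZeros(B) >= Mult2. A worked i8 instance of the underlying number theory (plain C++ using GCC/Clang's __builtin_ctz; the linear search for I stands in for the multiplicative-inverse step):

```cpp
// Find the smallest unsigned X with A*X == B (mod 2^8), following the
// three steps in SolveLinEquationWithOverflow's comments.
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
  const uint8_t A = 12, B = 8;

  // 1. D = gcd(A, 2^8) = 2^Mult2, with Mult2 = countTrailingZeros(A) = 2.
  const unsigned Mult2 = __builtin_ctz(A);
  // Solvability: B must be divisible by D. On a symbolic B this is exactly
  // the new GetMinTrailingZeros(B) >= Mult2 test.
  assert((unsigned)__builtin_ctz(B) >= Mult2);

  // 2. I = multiplicative inverse of (A / D) modulo 2^(8 - Mult2).
  const unsigned Mod = 1u << (8 - Mult2); // 64
  const unsigned AD = A >> Mult2;         // 3, odd, hence invertible
  unsigned I = 1;
  while ((AD * I) % Mod != 1)
    ++I;                                  // I == 43: 3 * 43 == 129 == 2*64 + 1

  // 3. X = (I * B mod 2^8) / D -- the patch now builds this as an exact udiv.
  const uint8_t X = (uint8_t)(I * B) >> Mult2; // (344 mod 256) / 4 == 22
  assert((uint8_t)(A * X) == B);               // 12 * 22 == 264 == 8 (mod 256)
  std::printf("X = %u\n", (unsigned)X);
}
```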
+ if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL)) + return true; + } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) { + Value *LL, *LR; + // FIXME: Once we have SDiv implemented, we can get rid of this matching. + using namespace llvm::PatternMatch; + if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) { + // Rules for division. + // We are going to perform some comparisons with Denominator and its + // derivative expressions. In general case, creating a SCEV for it may + // lead to a complex analysis of the entire graph, and in particular it + // can request trip count recalculation for the same loop. This would + // cache as SCEVCouldNotCompute to avoid the infinite recursion. To avoid + // this, we only want to create SCEVs that are constants in this section. + // So we bail if Denominator is not a constant. + if (!isa<ConstantInt>(LR)) + return false; + + auto *Denominator = cast<SCEVConstant>(getSCEV(LR)); + + // We want to make sure that LHS = FoundLHS / Denominator. If it is so, + // then a SCEV for the numerator already exists and matches with FoundLHS. + auto *Numerator = getExistingSCEV(LL); + if (!Numerator || Numerator->getType() != FoundLHS->getType()) + return false; + + // Make sure that the numerator matches with FoundLHS and the denominator + // is positive. + if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator)) + return false; + + auto *DTy = Denominator->getType(); + auto *FRHSTy = FoundRHS->getType(); + if (DTy->isPointerTy() != FRHSTy->isPointerTy()) + // One of types is a pointer and another one is not. We cannot extend + // them properly to a wider type, so let us just reject this case. + // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help + // to avoid this check. + return false; + + // Given that: + // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0. + auto *WTy = getWiderType(DTy, FRHSTy); + auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy); + auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy); + + // Try to prove the following rule: + // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS). + // For example, given that FoundLHS > 2. It means that FoundLHS is at + // least 3. If we divide it by Denominator < 4, we will have at least 1. + auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2)); + if (isKnownNonPositive(RHS) && + IsSGTViaContext(FoundRHSExt, DenomMinusTwo)) + return true; + + // Try to prove the following rule: + // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS). + // For example, given that FoundLHS > -3. Then FoundLHS is at least -2. + // If we divide it by Denominator > 2, then: + // 1. If FoundLHS is negative, then the result is 0. + // 2. If FoundLHS is non-negative, then the result is non-negative. + // Anyways, the result is non-negative. 
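The new isImpliedViaOperations rules above reduce to small arithmetic facts. For the division rule just stated: FoundRHS > Denominator - 2 gives FoundLHS >= FoundRHS + 1 >= Denominator, so FoundLHS / Denominator >= 1, which beats any RHS <= 0. The addition rule is similar: a non-negative term plus a term greater than RHS exceeds RHS, given no signed wrap. A brute-force check of both over small signed ranges:

```cpp
// Sanity-check two implication rules from isImpliedViaOperations on small
// concrete ranges (plain int arithmetic; no overflow occurs in these ranges,
// mirroring the patch's hasNoSignedWrap() requirement).
#include <cassert>

int main() {
  // Division rule: FoundLHS > FoundRHS, LHS = FoundLHS /s Denom, Denom > 0,
  // FoundRHS > Denom - 2 and RHS <= 0 together imply LHS > RHS.
  for (int Denom = 1; Denom <= 16; ++Denom)
    for (int FoundRHS = -32; FoundRHS <= 32; ++FoundRHS)
      for (int FoundLHS = FoundRHS + 1; FoundLHS <= 64; ++FoundLHS)
        for (int RHS = -8; RHS <= 0; ++RHS)
          if (FoundRHS > Denom - 2)
            assert(FoundLHS / Denom > RHS);

  // Addition rule: LL >= 0 and LR > RHS imply LL + LR > RHS.
  for (int LL = 0; LL <= 32; ++LL)
    for (int RHS = -16; RHS <= 16; ++RHS)
      for (int LR = RHS + 1; LR <= 32; ++LR)
        assert(LL + LR > RHS);
  return 0;
}
```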
+ auto *MinusOne = getNegativeSCEV(getOne(WTy)); + auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt); + if (isKnownNegative(RHS) && + IsSGTViaContext(FoundRHSExt, NegDenomMinusOne)) + return true; + } + } + + return false; +} + +bool +ScalarEvolution::isKnownViaSimpleReasoning(ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS) { + return isKnownPredicateViaConstantRanges(Pred, LHS, RHS) || + IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) || + IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) || + isKnownPredicateViaNoOverflow(Pred, LHS, RHS); +} + bool ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS, const SCEV *FoundRHS) { - auto IsKnownPredicateFull = - [this](ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) { - return isKnownPredicateViaConstantRanges(Pred, LHS, RHS) || - IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) || - IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) || - isKnownPredicateViaNoOverflow(Pred, LHS, RHS); - }; - switch (Pred) { default: llvm_unreachable("Unexpected ICmpInst::Predicate value!"); case ICmpInst::ICMP_EQ: @@ -8510,30 +8732,34 @@ ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred, break; case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_SLE: - if (IsKnownPredicateFull(ICmpInst::ICMP_SLE, LHS, FoundLHS) && - IsKnownPredicateFull(ICmpInst::ICMP_SGE, RHS, FoundRHS)) + if (isKnownViaSimpleReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) && + isKnownViaSimpleReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS)) return true; break; case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_SGE: - if (IsKnownPredicateFull(ICmpInst::ICMP_SGE, LHS, FoundLHS) && - IsKnownPredicateFull(ICmpInst::ICMP_SLE, RHS, FoundRHS)) + if (isKnownViaSimpleReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) && + isKnownViaSimpleReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS)) return true; break; case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_ULE: - if (IsKnownPredicateFull(ICmpInst::ICMP_ULE, LHS, FoundLHS) && - IsKnownPredicateFull(ICmpInst::ICMP_UGE, RHS, FoundRHS)) + if (isKnownViaSimpleReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) && + isKnownViaSimpleReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS)) return true; break; case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_UGE: - if (IsKnownPredicateFull(ICmpInst::ICMP_UGE, LHS, FoundLHS) && - IsKnownPredicateFull(ICmpInst::ICMP_ULE, RHS, FoundRHS)) + if (isKnownViaSimpleReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) && + isKnownViaSimpleReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS)) return true; break; } + // Maybe it can be proved via operations? + if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS)) + return true; + return false; } @@ -9524,6 +9750,7 @@ ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg) ValueExprMap(std::move(Arg.ValueExprMap)), PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)), WalkingBEDominatingConds(false), ProvingSplitPredicate(false), + MinTrailingZerosCache(std::move(Arg.MinTrailingZerosCache)), BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)), PredicatedBackedgeTakenCounts( std::move(Arg.PredicatedBackedgeTakenCounts)), @@ -9621,6 +9848,13 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, OS << "Unpredictable predicated backedge-taken count. 
"; } OS << "\n"; + + if (SE->hasLoopInvariantBackedgeTakenCount(L)) { + OS << "Loop "; + L->getHeader()->printAsOperand(OS, /*PrintType=*/false); + OS << ": "; + OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n"; + } } static StringRef loopDispositionToStr(ScalarEvolution::LoopDisposition LD) { @@ -9929,6 +10163,7 @@ void ScalarEvolution::forgetMemoizedResults(const SCEV *S) { SignedRanges.erase(S); ExprValueMap.erase(S); HasRecMap.erase(S); + MinTrailingZerosCache.erase(S); auto RemoveSCEVFromBackedgeMap = [S, this](DenseMap<const Loop *, BackedgeTakenInfo> &Map) { |