Diffstat (limited to 'llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp')
 llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp | 188
 1 file changed, 79 insertions(+), 109 deletions(-)
diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
index 3978e1e29825..a042146d7ace 100644
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -747,9 +747,8 @@ Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) {
// so that pointer operands are inserted first, which the code below relies on
// to form more involved GEPs.
SmallVector<std::pair<const Loop *, const SCEV *>, 8> OpsAndLoops;
- for (std::reverse_iterator<SCEVAddExpr::op_iterator> I(S->op_end()),
- E(S->op_begin()); I != E; ++I)
- OpsAndLoops.push_back(std::make_pair(getRelevantLoop(*I), *I));
+ for (const SCEV *Op : reverse(S->operands()))
+ OpsAndLoops.push_back(std::make_pair(getRelevantLoop(Op), Op));
// Sort by loop. Use a stable sort so that constants follow non-constants and
// pointer operands precede non-pointer operands.
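This first hunk is a pure modernization: the hand-rolled `std::reverse_iterator` pair is replaced by LLVM's `reverse` range adaptor (from `llvm/ADT/STLExtras.h`). A minimal standalone sketch of the same idiom, using C++20's `std::views::reverse` as a stand-in so it compiles without LLVM:

```cpp
#include <iostream>
#include <ranges>
#include <vector>

int main() {
  std::vector<int> Ops = {1, 2, 3, 4};

  // Old form: explicit reverse iterators, begin/end swapped by hand.
  for (auto I = Ops.rbegin(), E = Ops.rend(); I != E; ++I)
    std::cout << *I << ' ';
  std::cout << '\n';                       // 4 3 2 1

  // New form: a reversed view; the element, not the iterator, is in
  // scope, which is what the llvm::reverse(S->operands()) loop buys.
  for (int Op : std::views::reverse(Ops))
    std::cout << Op << ' ';
  std::cout << '\n';                       // 4 3 2 1
}
```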
@@ -765,7 +764,11 @@ Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) {
// This is the first operand. Just expand it.
Sum = expand(Op);
++I;
- } else if (PointerType *PTy = dyn_cast<PointerType>(Sum->getType())) {
+ continue;
+ }
+
+ assert(!Op->getType()->isPointerTy() && "Only first op can be pointer");
+ if (PointerType *PTy = dyn_cast<PointerType>(Sum->getType())) {
// The running sum expression is a pointer. Try to form a getelementptr
// at this level with that as the base.
SmallVector<const SCEV *, 4> NewOps;
@@ -779,16 +782,6 @@ Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) {
NewOps.push_back(X);
}
Sum = expandAddToGEP(NewOps.begin(), NewOps.end(), PTy, Ty, Sum);
- } else if (PointerType *PTy = dyn_cast<PointerType>(Op->getType())) {
- // The running sum is an integer, and there's a pointer at this level.
- // Try to form a getelementptr. If the running sum is instructions,
- // use a SCEVUnknown to avoid re-analyzing them.
- SmallVector<const SCEV *, 4> NewOps;
- NewOps.push_back(isa<Instruction>(Sum) ? SE.getUnknown(Sum) :
- SE.getSCEV(Sum));
- for (++I; I != E && I->first == CurLoop; ++I)
- NewOps.push_back(I->second);
- Sum = expandAddToGEP(NewOps.begin(), NewOps.end(), PTy, Ty, expand(Op));
} else if (Op->isNonConstantNegative()) {
// Instead of doing a negate and add, just do a subtract.
Value *W = expandCodeForImpl(SE.getNegativeSCEV(Op), Ty, false);
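The restructured loop leans on the stable sort above: pointer operands are ordered first, so only the very first operand can be a pointer, and the old "integer running sum plus pointer operand" GEP branch becomes dead code. The new `assert` documents that invariant. A standalone sketch of the pattern, with a hypothetical `Item` type standing in for sorted SCEV operands:

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

struct Item { bool IsPointer; int Val; };

int main() {
  std::vector<Item> Ops = {{false, 3}, {true, 100}, {false, 4}};

  // Pointers sort ahead of non-pointers; stability keeps ties in order.
  std::stable_sort(Ops.begin(), Ops.end(),
                   [](const Item &L, const Item &R) {
                     return L.IsPointer > R.IsPointer;
                   });

  long Sum = 0;
  bool First = true;
  for (const Item &Op : Ops) {
    if (First) {        // First operand: just take it, pointer or not.
      Sum = Op.Val;
      First = false;
      continue;
    }
    // Mirrors: assert(!Op->getType()->isPointerTy() && ...).
    assert(!Op.IsPointer && "Only first op can be pointer");
    Sum += Op.Val;
  }
  return Sum == 107 ? 0 : 1;
}
```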
@@ -817,9 +810,8 @@ Value *SCEVExpander::visitMulExpr(const SCEVMulExpr *S) {
// Collect all the mul operands in a loop, along with their associated loops.
// Iterate in reverse so that constants are emitted last, all else equal.
SmallVector<std::pair<const Loop *, const SCEV *>, 8> OpsAndLoops;
- for (std::reverse_iterator<SCEVMulExpr::op_iterator> I(S->op_end()),
- E(S->op_begin()); I != E; ++I)
- OpsAndLoops.push_back(std::make_pair(getRelevantLoop(*I), *I));
+ for (const SCEV *Op : reverse(S->operands()))
+ OpsAndLoops.push_back(std::make_pair(getRelevantLoop(Op), Op));
// Sort by loop. Use a stable sort so that constants follow non-constants.
llvm::stable_sort(OpsAndLoops, LoopCompare(SE.DT));
@@ -923,28 +915,6 @@ Value *SCEVExpander::visitUDivExpr(const SCEVUDivExpr *S) {
/*IsSafeToHoist*/ SE.isKnownNonZero(S->getRHS()));
}
-/// Move parts of Base into Rest to leave Base with the minimal
-/// expression that provides a pointer operand suitable for a
-/// GEP expansion.
-static void ExposePointerBase(const SCEV *&Base, const SCEV *&Rest,
- ScalarEvolution &SE) {
- while (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(Base)) {
- Base = A->getStart();
- Rest = SE.getAddExpr(Rest,
- SE.getAddRecExpr(SE.getConstant(A->getType(), 0),
- A->getStepRecurrence(SE),
- A->getLoop(),
- A->getNoWrapFlags(SCEV::FlagNW)));
- }
- if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(Base)) {
- Base = A->getOperand(A->getNumOperands()-1);
- SmallVector<const SCEV *, 8> NewAddOps(A->operands());
- NewAddOps.back() = Rest;
- Rest = SE.getAddExpr(NewAddOps);
- ExposePointerBase(Base, Rest, SE);
- }
-}
-
/// Determine if this is a well-behaved chain of instructions leading back to
/// the PHI. If so, it may be reused by expanded expressions.
bool SCEVExpander::isNormalAddRecExprPHI(PHINode *PN, Instruction *IncV,
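`ExposePointerBase` is deleted because `visitAddRecExpr` (further down in this diff) now asks ScalarEvolution for the pointer base directly. For reference, the helper peeled `{Start,+,Step}` into `Start` plus `{0,+,Step}` until a plain base was exposed; a toy structural model of that walk, with a hypothetical `AddRec` node and a counter in place of the accumulated `Rest` expression:

```cpp
struct Expr { virtual ~Expr() = default; };
struct AddRec : Expr {
  Expr *Start;
  int Step;
  AddRec(Expr *S, int St) : Start(S), Step(St) {}
};

// Walks Base = {X,+,F} down to X, as the deleted while-loop did; the
// real code folded each {0,+,F} into Rest, modeled here by a count.
Expr *exposeBase(Expr *Base, int &PeeledRecs) {
  while (auto *A = dynamic_cast<AddRec *>(Base)) {
    Base = A->Start;   // Base = A->getStart()
    ++PeeledRecs;      // Rest = Rest + {0,+,Step}
  }
  return Base;
}
```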
@@ -1125,22 +1095,6 @@ Value *SCEVExpander::expandIVInc(PHINode *PN, Value *StepV, const Loop *L,
return IncV;
}
-/// Hoist the addrec instruction chain rooted in the loop phi above the
-/// position. This routine assumes that this is possible (has been checked).
-void SCEVExpander::hoistBeforePos(DominatorTree *DT, Instruction *InstToHoist,
- Instruction *Pos, PHINode *LoopPhi) {
- do {
- if (DT->dominates(InstToHoist, Pos))
- break;
- // Make sure the increment is where we want it. But don't move it
- // down past a potential existing post-inc user.
- fixupInsertPoints(InstToHoist);
- InstToHoist->moveBefore(Pos);
- Pos = InstToHoist;
- InstToHoist = cast<Instruction>(InstToHoist->getOperand(0));
- } while (InstToHoist != LoopPhi);
-}
-
/// Check whether we can cheaply express the requested SCEV in terms of
/// the available PHI SCEV by truncation and/or inversion of the step.
static bool canBeCheaplyTransformed(ScalarEvolution &SE,
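`hoistBeforePos` also goes away; its only caller (removed below) went with it. For reference, a standalone toy model of what it did: walk the increment chain through operand 0 and move each instruction in front of the target position, where list order stands in for dominance:

```cpp
#include <algorithm>
#include <iterator>
#include <list>

struct Inst { Inst *Operand0 = nullptr; };

// Mirrors the deleted do/while: move InstToHoist before Pos unless it
// already sits above it, then climb to its operand until the loop phi.
void hoistBeforePos(std::list<Inst *> &Order, Inst *InstToHoist,
                    Inst *Pos, Inst *LoopPhi) {
  do {
    auto IH = std::find(Order.begin(), Order.end(), InstToHoist);
    auto P = std::find(Order.begin(), Order.end(), Pos);
    if (std::distance(Order.begin(), IH) < std::distance(Order.begin(), P))
      break;                      // already "dominates" Pos: done
    Order.splice(P, Order, IH);   // InstToHoist->moveBefore(Pos)
    Pos = InstToHoist;
    InstToHoist = InstToHoist->Operand0;
  } while (InstToHoist != LoopPhi);
}
```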
@@ -1264,8 +1218,6 @@ SCEVExpander::getAddRecExprPHILiterally(const SCEVAddRecExpr *Normalized,
if (LSRMode) {
if (!isExpandedAddRecExprPHI(&PN, TempIncV, L))
continue;
- if (L == IVIncInsertLoop && !hoistIVInc(TempIncV, IVIncInsertPos))
- continue;
} else {
if (!isNormalAddRecExprPHI(&PN, TempIncV, L))
continue;
@@ -1293,11 +1245,6 @@ SCEVExpander::getAddRecExprPHILiterally(const SCEVAddRecExpr *Normalized,
}
if (AddRecPhiMatch) {
- // Potentially, move the increment. We have made sure in
- // isExpandedAddRecExprPHI or hoistIVInc that this is possible.
- if (L == IVIncInsertLoop)
- hoistBeforePos(&SE.DT, IncV, IVIncInsertPos, AddRecPhiMatch);
-
// Ok, the add recurrence looks usable.
// Remember this PHI, even in post-inc mode.
InsertedValues.insert(AddRecPhiMatch);
@@ -1597,29 +1544,17 @@ Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) {
// {X,+,F} --> X + {0,+,F}
if (!S->getStart()->isZero()) {
+ if (PointerType *PTy = dyn_cast<PointerType>(S->getType())) {
+ Value *StartV = expand(SE.getPointerBase(S));
+ assert(StartV->getType() == PTy && "Pointer type mismatch for GEP!");
+ return expandAddToGEP(SE.removePointerBase(S), PTy, Ty, StartV);
+ }
+
SmallVector<const SCEV *, 4> NewOps(S->operands());
NewOps[0] = SE.getConstant(Ty, 0);
const SCEV *Rest = SE.getAddRecExpr(NewOps, L,
S->getNoWrapFlags(SCEV::FlagNW));
- // Turn things like ptrtoint+arithmetic+inttoptr into GEP. See the
- // comments on expandAddToGEP for details.
- const SCEV *Base = S->getStart();
- // Dig into the expression to find the pointer base for a GEP.
- const SCEV *ExposedRest = Rest;
- ExposePointerBase(Base, ExposedRest, SE);
- // If we found a pointer, expand the AddRec with a GEP.
- if (PointerType *PTy = dyn_cast<PointerType>(Base->getType())) {
- // Make sure the Base isn't something exotic, such as a multiplied
- // or divided pointer value. In those cases, the result type isn't
- // actually a pointer type.
- if (!isa<SCEVMulExpr>(Base) && !isa<SCEVUDivExpr>(Base)) {
- Value *StartV = expand(Base);
- assert(StartV->getType() == PTy && "Pointer type mismatch for GEP!");
- return expandAddToGEP(ExposedRest, PTy, Ty, StartV);
- }
- }
-
// Just do a normal add. Pre-expand the operands to suppress folding.
//
// The LHS and RHS values are factored out of the expand call to make the
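Here is the replacement for the deleted `ExposePointerBase` machinery: when the addrec is pointer-typed, `SE.getPointerBase(S)` recovers the base (e.g. `%p` for `{%p,+,4}`) and `SE.removePointerBase(S)` returns the integer remainder (`{0,+,4}`), which is then expanded as a GEP off the base. A hedged sketch against the real ScalarEvolution API; the wrapper function itself is illustrative, not part of the patch:

```cpp
#include "llvm/Analysis/ScalarEvolution.h"
#include <cassert>
#include <utility>
using namespace llvm;

// Splits a pointer-typed SCEV into (pointer base, integer offset expr);
// for S = {%p,+,4}<L> this yields (%p, {0,+,4}<L>).
static std::pair<const SCEV *, const SCEV *>
splitPointerSCEV(ScalarEvolution &SE, const SCEV *S) {
  assert(S->getType()->isPointerTy() && "expected a pointer-typed SCEV");
  const SCEV *Base = SE.getPointerBase(S);
  const SCEV *Offset = SE.removePointerBase(S);
  return {Base, Offset};
}
```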
@@ -1898,6 +1833,22 @@ Value *SCEVExpander::expandCodeForImpl(const SCEV *SH, Type *Ty, bool Root) {
return V;
}
+/// Check whether value has nuw/nsw/exact set but SCEV does not.
+/// TODO: In reality it is better to check the poison recursively
+/// but this is better than nothing.
+static bool SCEVLostPoisonFlags(const SCEV *S, const Instruction *I) {
+ if (isa<OverflowingBinaryOperator>(I)) {
+ if (auto *NS = dyn_cast<SCEVNAryExpr>(S)) {
+ if (I->hasNoSignedWrap() && !NS->hasNoSignedWrap())
+ return true;
+ if (I->hasNoUnsignedWrap() && !NS->hasNoUnsignedWrap())
+ return true;
+ }
+ } else if (isa<PossiblyExactOperator>(I) && I->isExact())
+ return true;
+ return false;
+}
+
ScalarEvolution::ValueOffsetPair
SCEVExpander::FindValueInExprValueMap(const SCEV *S,
const Instruction *InsertPt) {
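The new `SCEVLostPoisonFlags` guards value reuse in `FindValueInExprValueMap` (next hunk): if a cached instruction carries `nuw`/`nsw`/`exact` flags that the SCEV does not also guarantee, reusing it could smuggle poison into an expansion that assumed flag-free arithmetic. A standalone model of the predicate, with hypothetical flag structs in place of `Instruction`/`SCEVNAryExpr`:

```cpp
#include <cassert>

struct InstFlags { bool NSW, NUW, Exact, IsOverflowingBinOp, IsExactOp; };
struct SCEVFlags { bool NSW, NUW; };

// Reuse is rejected when the instruction promises more than the SCEV.
bool lostPoisonFlags(const SCEVFlags &S, const InstFlags &I) {
  if (I.IsOverflowingBinOp) {
    if (I.NSW && !S.NSW) return true;   // inst says nsw, SCEV can't prove it
    if (I.NUW && !S.NUW) return true;   // likewise for nuw
  } else if (I.IsExactOp && I.Exact)
    return true;                        // exact udiv/ashr/lshr: be safe
  return false;
}

int main() {
  // `add nuw` cached in IR, but the SCEV carries no nuw: unsafe to reuse.
  assert(lostPoisonFlags({false, false}, {false, true, false, true, false}));
  // Flags agree: reuse is fine.
  assert(!lostPoisonFlags({true, true}, {true, true, false, true, false}));
}
```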
@@ -1907,19 +1858,22 @@ SCEVExpander::FindValueInExprValueMap(const SCEV *S,
if (CanonicalMode || !SE.containsAddRecurrence(S)) {
// If S is scConstant, it may be worse to reuse an existing Value.
if (S->getSCEVType() != scConstant && Set) {
- // Choose a Value from the set which dominates the insertPt.
- // insertPt should be inside the Value's parent loop so as not to break
+ // Choose a Value from the set which dominates the InsertPt.
+ // InsertPt should be inside the Value's parent loop so as not to break
// the LCSSA form.
for (auto const &VOPair : *Set) {
Value *V = VOPair.first;
ConstantInt *Offset = VOPair.second;
- Instruction *EntInst = nullptr;
- if (V && isa<Instruction>(V) && (EntInst = cast<Instruction>(V)) &&
- S->getType() == V->getType() &&
- EntInst->getFunction() == InsertPt->getFunction() &&
+ Instruction *EntInst = dyn_cast_or_null<Instruction>(V);
+ if (!EntInst)
+ continue;
+
+ assert(EntInst->getFunction() == InsertPt->getFunction());
+ if (S->getType() == V->getType() &&
SE.DT.dominates(EntInst, InsertPt) &&
(SE.LI.getLoopFor(EntInst->getParent()) == nullptr ||
- SE.LI.getLoopFor(EntInst->getParent())->contains(InsertPt)))
+ SE.LI.getLoopFor(EntInst->getParent())->contains(InsertPt)) &&
+ !SCEVLostPoisonFlags(S, EntInst))
return {V, Offset};
}
}
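The lookup loop itself is tidied: `dyn_cast_or_null` folds the null check, `isa<>` test, and `cast<>` into one step, and the same-function condition is promoted to an assert, since the value map should never hand back a value from another function. The standard-C++ analogue of that fold, using `dynamic_cast`, which like `dyn_cast_or_null` tolerates null:

```cpp
#include <iostream>

struct Value { virtual ~Value() = default; };
struct Instruction : Value { const char *Name = "inst"; };

// dynamic_cast on a null pointer yields null, exactly like
// dyn_cast_or_null<Instruction>(V) in LLVM.
const Instruction *asInstruction(const Value *V) {
  return dynamic_cast<const Instruction *>(V);
}

int main() {
  Instruction I;
  const Value *Good = &I, *Null = nullptr;
  if (const Instruction *EntInst = asInstruction(Good))
    std::cout << EntInst->Name << '\n';
  if (!asInstruction(Null))
    std::cout << "skipped null entry\n";
}
```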
@@ -2068,7 +2022,9 @@ SCEVExpander::replaceCongruentIVs(Loop *L, const DominatorTree *DT,
Phis.push_back(&PN);
if (TTI)
- llvm::sort(Phis, [](Value *LHS, Value *RHS) {
+ // Use stable_sort to preserve order of equivalent PHIs, so the order
+ // of the sorted Phis is the same from run to run on the same loop.
+ llvm::stable_sort(Phis, [](Value *LHS, Value *RHS) {
// Put pointers at the back and make sure pointer < pointer = false.
if (!LHS->getType()->isIntegerTy() || !RHS->getType()->isIntegerTy())
return RHS->getType()->isIntegerTy() && !LHS->getType()->isIntegerTy();
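`llvm::sort` makes no promise about the relative order of PHIs that compare equal (and under expensive checks it deliberately shuffles them), so `replaceCongruentIVs` could emit different-but-equivalent output across runs. `llvm::stable_sort` pins ties to their discovery order. A standalone demonstration of the property:

```cpp
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Key models "is an integer PHI"; the name models the PHI's identity.
  std::vector<std::pair<int, std::string>> Phis = {
      {1, "%iv.a"}, {0, "%ptr.x"}, {1, "%iv.b"}, {0, "%ptr.y"}};

  // Integer PHIs (key 1) sort ahead of pointer PHIs (key 0); ties among
  // the %iv.* and %ptr.* entries keep their insertion order.
  std::stable_sort(Phis.begin(), Phis.end(),
                   [](const auto &L, const auto &R) {
                     return L.first > R.first;
                   });

  for (const auto &[K, Name] : Phis)
    std::cout << Name << '\n';  // %iv.a %iv.b %ptr.x %ptr.y, every run
}
```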
@@ -2524,18 +2480,14 @@ Value *SCEVExpander::generateOverflowCheck(const SCEVAddRecExpr *AR,
IntegerType *Ty =
IntegerType::get(Loc->getContext(), SE.getTypeSizeInBits(ARTy));
- Type *ARExpandTy = DL.isNonIntegralPointerType(ARTy) ? ARTy : Ty;
Value *StepValue = expandCodeForImpl(Step, Ty, Loc, false);
Value *NegStepValue =
expandCodeForImpl(SE.getNegativeSCEV(Step), Ty, Loc, false);
- Value *StartValue = expandCodeForImpl(
- isa<PointerType>(ARExpandTy) ? Start
- : SE.getPtrToIntExpr(Start, ARExpandTy),
- ARExpandTy, Loc, false);
+ Value *StartValue = expandCodeForImpl(Start, ARTy, Loc, false);
ConstantInt *Zero =
- ConstantInt::get(Loc->getContext(), APInt::getNullValue(DstBits));
+ ConstantInt::get(Loc->getContext(), APInt::getZero(DstBits));
Builder.SetInsertPoint(Loc);
// Compute |Step|
@@ -2544,25 +2496,33 @@ Value *SCEVExpander::generateOverflowCheck(const SCEVAddRecExpr *AR,
// Get the backedge taken count and truncate or extended to the AR type.
Value *TruncTripCount = Builder.CreateZExtOrTrunc(TripCountVal, Ty);
- auto *MulF = Intrinsic::getDeclaration(Loc->getModule(),
- Intrinsic::umul_with_overflow, Ty);
// Compute |Step| * Backedge
- CallInst *Mul = Builder.CreateCall(MulF, {AbsStep, TruncTripCount}, "mul");
- Value *MulV = Builder.CreateExtractValue(Mul, 0, "mul.result");
- Value *OfMul = Builder.CreateExtractValue(Mul, 1, "mul.overflow");
+ Value *MulV, *OfMul;
+ if (Step->isOne()) {
+ // Special-case Step of one. Potentially-costly `umul_with_overflow` isn't
+ // needed, there is never an overflow, so to avoid artificially inflating
+ // the cost of the check, directly emit the optimized IR.
+ MulV = TruncTripCount;
+ OfMul = ConstantInt::getFalse(MulV->getContext());
+ } else {
+ auto *MulF = Intrinsic::getDeclaration(Loc->getModule(),
+ Intrinsic::umul_with_overflow, Ty);
+ CallInst *Mul = Builder.CreateCall(MulF, {AbsStep, TruncTripCount}, "mul");
+ MulV = Builder.CreateExtractValue(Mul, 0, "mul.result");
+ OfMul = Builder.CreateExtractValue(Mul, 1, "mul.overflow");
+ }
// Compute:
// Start + |Step| * Backedge < Start
// Start - |Step| * Backedge > Start
Value *Add = nullptr, *Sub = nullptr;
- if (PointerType *ARPtrTy = dyn_cast<PointerType>(ARExpandTy)) {
- const SCEV *MulS = SE.getSCEV(MulV);
- const SCEV *NegMulS = SE.getNegativeSCEV(MulS);
- Add = Builder.CreateBitCast(expandAddToGEP(MulS, ARPtrTy, Ty, StartValue),
- ARPtrTy);
- Sub = Builder.CreateBitCast(
- expandAddToGEP(NegMulS, ARPtrTy, Ty, StartValue), ARPtrTy);
+ if (PointerType *ARPtrTy = dyn_cast<PointerType>(ARTy)) {
+ StartValue = InsertNoopCastOfTo(
+ StartValue, Builder.getInt8PtrTy(ARPtrTy->getAddressSpace()));
+ Value *NegMulV = Builder.CreateNeg(MulV);
+ Add = Builder.CreateGEP(Builder.getInt8Ty(), StartValue, MulV);
+ Sub = Builder.CreateGEP(Builder.getInt8Ty(), StartValue, NegMulV);
} else {
Add = Builder.CreateAdd(StartValue, MulV);
Sub = Builder.CreateSub(StartValue, MulV);
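Two things change in the overflow check. First, the `Step->isOne()` fast path skips the `umul.with.overflow` intrinsic entirely, since 1 * TripCount can never wrap. Second, the pointer case now emits plain i8 GEPs off the start value instead of round-tripping through `ptrtoint`, which also lets non-integral pointers work without the old `ARExpandTy` detour removed in the previous hunk. A standalone model of the fast path, with `__builtin_mul_overflow` (GCC/Clang) standing in for the intrinsic:

```cpp
#include <cstdint>
#include <iostream>

// Models: MulV/OfMul are the intrinsic's result and overflow bit.
void mulStepByTripCount(uint64_t Step, uint64_t TripCount,
                        uint64_t &MulV, bool &OfMul) {
  if (Step == 1) {
    MulV = TripCount;   // MulV = TruncTripCount
    OfMul = false;      // OfMul = i1 false; no intrinsic call emitted
  } else {
    OfMul = __builtin_mul_overflow(Step, TripCount, &MulV);
  }
}

int main() {
  uint64_t MulV; bool OfMul;
  mulStepByTripCount(1, UINT64_MAX, MulV, OfMul);
  std::cout << MulV << " overflow=" << OfMul << '\n';  // never overflows
  mulStepByTripCount(2, UINT64_MAX, MulV, OfMul);
  std::cout << MulV << " overflow=" << OfMul << '\n';  // overflow=1
}
```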
@@ -2686,9 +2646,11 @@ namespace {
// perfectly reduced form, which can't be guaranteed.
struct SCEVFindUnsafe {
ScalarEvolution &SE;
+ bool CanonicalMode;
bool IsUnsafe;
- SCEVFindUnsafe(ScalarEvolution &se): SE(se), IsUnsafe(false) {}
+ SCEVFindUnsafe(ScalarEvolution &SE, bool CanonicalMode)
+ : SE(SE), CanonicalMode(CanonicalMode), IsUnsafe(false) {}
bool follow(const SCEV *S) {
if (const SCEVUDivExpr *D = dyn_cast<SCEVUDivExpr>(S)) {
@@ -2704,6 +2666,14 @@ struct SCEVFindUnsafe {
IsUnsafe = true;
return false;
}
+
+ // For non-affine addrecs or in non-canonical mode we need a preheader
+ // to insert into.
+ if (!AR->getLoop()->getLoopPreheader() &&
+ (!CanonicalMode || !AR->isAffine())) {
+ IsUnsafe = true;
+ return false;
+ }
}
return true;
}
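The new bailout captures a real precondition: expanding a non-affine addrec, or any addrec when the expander is not in canonical mode, has to place code in the loop preheader, so a loop without one makes the whole expression unsafe to expand. A standalone sketch of the visitor's early-exit shape, with a hypothetical `Node`/`visitAll` standing in for SCEV and `llvm::visitAll`:

```cpp
#include <vector>

struct Node {
  bool IsAddRec = false, IsAffine = true, HasPreheader = true;
  std::vector<Node *> Ops;
};

struct FindUnsafe {
  bool CanonicalMode;
  bool IsUnsafe = false;
  // Returning false stops the descent, mirroring SCEVFindUnsafe::follow.
  bool follow(const Node *N) {
    if (N->IsAddRec && !N->HasPreheader &&
        (!CanonicalMode || !N->IsAffine)) {
      IsUnsafe = true;  // no preheader to expand into
      return false;
    }
    return true;
  }
};

void visitAll(const Node *N, FindUnsafe &V) {
  if (!V.follow(N))
    return;
  for (const Node *Op : N->Ops)
    visitAll(Op, V);
}
```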
@@ -2712,8 +2682,8 @@ struct SCEVFindUnsafe {
}
namespace llvm {
-bool isSafeToExpand(const SCEV *S, ScalarEvolution &SE) {
- SCEVFindUnsafe Search(SE);
+bool isSafeToExpand(const SCEV *S, ScalarEvolution &SE, bool CanonicalMode) {
+ SCEVFindUnsafe Search(SE, CanonicalMode);
visitAll(S, Search);
return !Search.IsUnsafe;
}
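Callers now thread their expansion mode into the safety check; a hedged usage sketch (the surrounding context is illustrative):

```cpp
// E.g. a user expanding in LSR-style non-canonical mode:
if (!isSafeToExpand(S, SE, /*CanonicalMode=*/false))
  return; // don't try to expand this SCEV here
```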