aboutsummaryrefslogtreecommitdiff
path: root/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerSwitch.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Transforms/Utils/LowerSwitch.cpp')
-rw-r--r--contrib/llvm-project/llvm/lib/Transforms/Utils/LowerSwitch.cpp127
1 files changed, 67 insertions, 60 deletions
diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerSwitch.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerSwitch.cpp
index 44aeb26fadf9..227de425ff85 100644
--- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerSwitch.cpp
+++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerSwitch.cpp
@@ -51,9 +51,9 @@ using namespace llvm;
namespace {
- struct IntRange {
- int64_t Low, High;
- };
+struct IntRange {
+ APInt Low, High;
+};
} // end anonymous namespace
@@ -66,8 +66,8 @@ bool IsInRanges(const IntRange &R, const std::vector<IntRange> &Ranges) {
// then check if the Low field is <= R.Low. If so, we
// have a Range that covers R.
auto I = llvm::lower_bound(
- Ranges, R, [](IntRange A, IntRange B) { return A.High < B.High; });
- return I != Ranges.end() && I->Low <= R.Low;
+ Ranges, R, [](IntRange A, IntRange B) { return A.High.slt(B.High); });
+ return I != Ranges.end() && I->Low.sle(R.Low);
}
struct CaseRange {
@@ -116,15 +116,14 @@ raw_ostream &operator<<(raw_ostream &O, const CaseVector &C) {
/// 2) Removed if subsequent incoming values now share the same case, i.e.,
/// multiple outcome edges are condensed into one. This is necessary to keep the
/// number of phi values equal to the number of branches to SuccBB.
-void FixPhis(
- BasicBlock *SuccBB, BasicBlock *OrigBB, BasicBlock *NewBB,
- const unsigned NumMergedCases = std::numeric_limits<unsigned>::max()) {
+void FixPhis(BasicBlock *SuccBB, BasicBlock *OrigBB, BasicBlock *NewBB,
+ const APInt &NumMergedCases) {
for (auto &I : SuccBB->phis()) {
PHINode *PN = cast<PHINode>(&I);
// Only update the first occurrence if NewBB exists.
unsigned Idx = 0, E = PN->getNumIncomingValues();
- unsigned LocalNumMergedCases = NumMergedCases;
+ APInt LocalNumMergedCases = NumMergedCases;
for (; Idx != E && NewBB; ++Idx) {
if (PN->getIncomingBlock(Idx) == OrigBB) {
PN->setIncomingBlock(Idx, NewBB);
@@ -139,10 +138,10 @@ void FixPhis(
// Remove additional occurrences coming from condensed cases and keep the
// number of incoming values equal to the number of branches to SuccBB.
SmallVector<unsigned, 8> Indices;
- for (; LocalNumMergedCases > 0 && Idx < E; ++Idx)
+ for (; LocalNumMergedCases.ugt(0) && Idx < E; ++Idx)
if (PN->getIncomingBlock(Idx) == OrigBB) {
Indices.push_back(Idx);
- LocalNumMergedCases--;
+ LocalNumMergedCases -= 1;
}
// Remove incoming values in the reverse order to prevent invalidating
// *successive* index.
@@ -160,7 +159,7 @@ BasicBlock *NewLeafBlock(CaseRange &Leaf, Value *Val, ConstantInt *LowerBound,
BasicBlock *Default) {
Function *F = OrigBlock->getParent();
BasicBlock *NewLeaf = BasicBlock::Create(Val->getContext(), "LeafBlock");
- F->getBasicBlockList().insert(++OrigBlock->getIterator(), NewLeaf);
+ F->insert(++OrigBlock->getIterator(), NewLeaf);
// Emit comparison
ICmpInst *Comp = nullptr;
@@ -209,8 +208,8 @@ BasicBlock *NewLeafBlock(CaseRange &Leaf, Value *Val, ConstantInt *LowerBound,
for (BasicBlock::iterator I = Succ->begin(); isa<PHINode>(I); ++I) {
PHINode *PN = cast<PHINode>(I);
// Remove all but one incoming entries from the cluster
- uint64_t Range = Leaf.High->getSExtValue() - Leaf.Low->getSExtValue();
- for (uint64_t j = 0; j < Range; ++j) {
+ APInt Range = Leaf.High->getValue() - Leaf.Low->getValue();
+ for (APInt j(Range.getBitWidth(), 0, true); j.slt(Range); ++j) {
PN->removeIncomingValue(OrigBlock);
}
@@ -241,8 +240,7 @@ BasicBlock *SwitchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound,
// emitting the code that checks if the value actually falls in the range
// because the bounds already tell us so.
if (Begin->Low == LowerBound && Begin->High == UpperBound) {
- unsigned NumMergedCases = 0;
- NumMergedCases = UpperBound->getSExtValue() - LowerBound->getSExtValue();
+ APInt NumMergedCases = UpperBound->getValue() - LowerBound->getValue();
FixPhis(Begin->BB, OrigBlock, Predecessor, NumMergedCases);
return Begin->BB;
}
@@ -273,25 +271,24 @@ BasicBlock *SwitchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound,
if (!UnreachableRanges.empty()) {
// Check if the gap between LHS's highest and NewLowerBound is unreachable.
- int64_t GapLow = LHS.back().High->getSExtValue() + 1;
- int64_t GapHigh = NewLowerBound->getSExtValue() - 1;
- IntRange Gap = { GapLow, GapHigh };
- if (GapHigh >= GapLow && IsInRanges(Gap, UnreachableRanges))
+ APInt GapLow = LHS.back().High->getValue() + 1;
+ APInt GapHigh = NewLowerBound->getValue() - 1;
+ IntRange Gap = {GapLow, GapHigh};
+ if (GapHigh.sge(GapLow) && IsInRanges(Gap, UnreachableRanges))
NewUpperBound = LHS.back().High;
}
- LLVM_DEBUG(dbgs() << "LHS Bounds ==> [" << LowerBound->getSExtValue() << ", "
- << NewUpperBound->getSExtValue() << "]\n"
- << "RHS Bounds ==> [" << NewLowerBound->getSExtValue()
- << ", " << UpperBound->getSExtValue() << "]\n");
+ LLVM_DEBUG(dbgs() << "LHS Bounds ==> [" << LowerBound->getValue() << ", "
+ << NewUpperBound->getValue() << "]\n"
+ << "RHS Bounds ==> [" << NewLowerBound->getValue() << ", "
+ << UpperBound->getValue() << "]\n");
// Create a new node that checks if the value is < pivot. Go to the
// left branch if it is and right branch if not.
- Function* F = OrigBlock->getParent();
- BasicBlock* NewNode = BasicBlock::Create(Val->getContext(), "NodeBlock");
+ Function *F = OrigBlock->getParent();
+ BasicBlock *NewNode = BasicBlock::Create(Val->getContext(), "NodeBlock");
- ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_SLT,
- Val, Pivot.Low, "Pivot");
+ ICmpInst *Comp = new ICmpInst(ICmpInst::ICMP_SLT, Val, Pivot.Low, "Pivot");
BasicBlock *LBranch =
SwitchConvert(LHS.begin(), LHS.end(), LowerBound, NewUpperBound, Val,
@@ -300,8 +297,8 @@ BasicBlock *SwitchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound,
SwitchConvert(RHS.begin(), RHS.end(), NewLowerBound, UpperBound, Val,
NewNode, OrigBlock, Default, UnreachableRanges);
- F->getBasicBlockList().insert(++OrigBlock->getIterator(), NewNode);
- NewNode->getInstList().push_back(Comp);
+ F->insert(++OrigBlock->getIterator(), NewNode);
+ Comp->insertInto(NewNode, NewNode->end());
BranchInst::Create(LBranch, RBranch, Comp, NewNode);
return NewNode;
@@ -328,14 +325,15 @@ unsigned Clusterify(CaseVector &Cases, SwitchInst *SI) {
if (Cases.size() >= 2) {
CaseItr I = Cases.begin();
for (CaseItr J = std::next(I), E = Cases.end(); J != E; ++J) {
- int64_t nextValue = J->Low->getSExtValue();
- int64_t currentValue = I->High->getSExtValue();
- BasicBlock* nextBB = J->BB;
- BasicBlock* currentBB = I->BB;
+ const APInt &nextValue = J->Low->getValue();
+ const APInt &currentValue = I->High->getValue();
+ BasicBlock *nextBB = J->BB;
+ BasicBlock *currentBB = I->BB;
// If the two neighboring cases go to the same destination, merge them
// into a single case.
- assert(nextValue > currentValue && "Cases should be strictly ascending");
+ assert(nextValue.sgt(currentValue) &&
+ "Cases should be strictly ascending");
if ((nextValue == currentValue + 1) && (currentBB == nextBB)) {
I->High = J->High;
// FIXME: Combine branch weights.
@@ -356,8 +354,8 @@ void ProcessSwitchInst(SwitchInst *SI,
AssumptionCache *AC, LazyValueInfo *LVI) {
BasicBlock *OrigBlock = SI->getParent();
Function *F = OrigBlock->getParent();
- Value *Val = SI->getCondition(); // The value we are switching on...
- BasicBlock* Default = SI->getDefaultDest();
+ Value *Val = SI->getCondition(); // The value we are switching on...
+ BasicBlock *Default = SI->getDefaultDest();
// Don't handle unreachable blocks. If there are successors with phis, this
// would leave them behind with missing predecessors.
@@ -370,6 +368,12 @@ void ProcessSwitchInst(SwitchInst *SI,
// Prepare cases vector.
CaseVector Cases;
const unsigned NumSimpleCases = Clusterify(Cases, SI);
+ IntegerType *IT = cast<IntegerType>(SI->getCondition()->getType());
+ const unsigned BitWidth = IT->getBitWidth();
+ // Explictly use higher precision to prevent unsigned overflow where
+ // `UnsignedMax - 0 + 1 == 0`
+ APInt UnsignedZero(BitWidth + 1, 0);
+ APInt UnsignedMax = APInt::getMaxValue(BitWidth);
LLVM_DEBUG(dbgs() << "Clusterify finished. Total clusters: " << Cases.size()
<< ". Total non-default cases: " << NumSimpleCases
<< "\nCase clusters: " << Cases << "\n");
@@ -378,7 +382,7 @@ void ProcessSwitchInst(SwitchInst *SI,
if (Cases.empty()) {
BranchInst::Create(Default, OrigBlock);
// Remove all the references from Default's PHIs to OrigBlock, but one.
- FixPhis(Default, OrigBlock, OrigBlock);
+ FixPhis(Default, OrigBlock, OrigBlock, UnsignedMax);
SI->eraseFromParent();
return;
}
@@ -415,8 +419,8 @@ void ProcessSwitchInst(SwitchInst *SI,
// the unlikely event that some of them survived, we just conservatively
// maintain the invariant that all the cases lie between the bounds. This
// may, however, still render the default case effectively unreachable.
- APInt Low = Cases.front().Low->getValue();
- APInt High = Cases.back().High->getValue();
+ const APInt &Low = Cases.front().Low->getValue();
+ const APInt &High = Cases.back().High->getValue();
APInt Min = APIntOps::smin(ValRange.getSignedMin(), Low);
APInt Max = APIntOps::smax(ValRange.getSignedMax(), High);
@@ -428,35 +432,38 @@ void ProcessSwitchInst(SwitchInst *SI,
std::vector<IntRange> UnreachableRanges;
if (DefaultIsUnreachableFromSwitch) {
- DenseMap<BasicBlock *, unsigned> Popularity;
- unsigned MaxPop = 0;
+ DenseMap<BasicBlock *, APInt> Popularity;
+ APInt MaxPop(UnsignedZero);
BasicBlock *PopSucc = nullptr;
- IntRange R = {std::numeric_limits<int64_t>::min(),
- std::numeric_limits<int64_t>::max()};
+ APInt SignedMax = APInt::getSignedMaxValue(BitWidth);
+ APInt SignedMin = APInt::getSignedMinValue(BitWidth);
+ IntRange R = {SignedMin, SignedMax};
UnreachableRanges.push_back(R);
for (const auto &I : Cases) {
- int64_t Low = I.Low->getSExtValue();
- int64_t High = I.High->getSExtValue();
+ const APInt &Low = I.Low->getValue();
+ const APInt &High = I.High->getValue();
IntRange &LastRange = UnreachableRanges.back();
- if (LastRange.Low == Low) {
+ if (LastRange.Low.eq(Low)) {
// There is nothing left of the previous range.
UnreachableRanges.pop_back();
} else {
// Terminate the previous range.
- assert(Low > LastRange.Low);
+ assert(Low.sgt(LastRange.Low));
LastRange.High = Low - 1;
}
- if (High != std::numeric_limits<int64_t>::max()) {
- IntRange R = { High + 1, std::numeric_limits<int64_t>::max() };
+ if (High.ne(SignedMax)) {
+ IntRange R = {High + 1, SignedMax};
UnreachableRanges.push_back(R);
}
// Count popularity.
- int64_t N = High - Low + 1;
- unsigned &Pop = Popularity[I.BB];
- if ((Pop += N) > MaxPop) {
+ assert(High.sge(Low) && "Popularity shouldn't be negative.");
+ APInt N = High.sext(BitWidth + 1) - Low.sext(BitWidth + 1) + 1;
+ // Explict insert to make sure the bitwidth of APInts match
+ APInt &Pop = Popularity.insert({I.BB, APInt(UnsignedZero)}).first->second;
+ if ((Pop += N).ugt(MaxPop)) {
MaxPop = Pop;
PopSucc = I.BB;
}
@@ -465,10 +472,10 @@ void ProcessSwitchInst(SwitchInst *SI,
/* UnreachableRanges should be sorted and the ranges non-adjacent. */
for (auto I = UnreachableRanges.begin(), E = UnreachableRanges.end();
I != E; ++I) {
- assert(I->Low <= I->High);
+ assert(I->Low.sle(I->High));
auto Next = I + 1;
if (Next != E) {
- assert(Next->Low > I->High);
+ assert(Next->Low.sgt(I->High));
}
}
#endif
@@ -481,7 +488,6 @@ void ProcessSwitchInst(SwitchInst *SI,
// Use the most popular block as the new default, reducing the number of
// cases.
- assert(MaxPop > 0 && PopSucc);
Default = PopSucc;
llvm::erase_if(Cases,
[PopSucc](const CaseRange &R) { return R.BB == PopSucc; });
@@ -492,8 +498,9 @@ void ProcessSwitchInst(SwitchInst *SI,
SI->eraseFromParent();
// As all the cases have been replaced with a single branch, only keep
// one entry in the PHI nodes.
- for (unsigned I = 0 ; I < (MaxPop - 1) ; ++I)
- PopSucc->removePredecessor(OrigBlock);
+ if (!MaxPop.isZero())
+ for (APInt I(UnsignedZero); I.ult(MaxPop - 1); ++I)
+ PopSucc->removePredecessor(OrigBlock);
return;
}
@@ -513,14 +520,14 @@ void ProcessSwitchInst(SwitchInst *SI,
// that SwitchBlock is the same as Default, under which the PHIs in Default
// are fixed inside SwitchConvert().
if (SwitchBlock != Default)
- FixPhis(Default, OrigBlock, nullptr);
+ FixPhis(Default, OrigBlock, nullptr, UnsignedMax);
// Branch to our shiny new if-then stuff...
BranchInst::Create(SwitchBlock, OrigBlock);
// We are now done with the switch instruction, delete it.
BasicBlock *OldDefault = SI->getDefaultDest();
- OrigBlock->getInstList().erase(SI);
+ SI->eraseFromParent();
// If the Default block has no more predecessors just add it to DeleteList.
if (pred_empty(OldDefault))