| author | Dimitry Andric <dim@FreeBSD.org> | 2024-01-24 19:17:23 +0000 |
| --- | --- | --- |
| committer | Dimitry Andric <dim@FreeBSD.org> | 2024-04-06 20:13:49 +0000 |
| commit | 7a6dacaca14b62ca4b74406814becb87a3fefac0 (patch) | |
| tree | 273a870ac27484bb1f5ee55e7ef0dc0d061f63e7 /contrib/llvm-project/llvm/lib/CodeGen/SelectOptimize.cpp | |
| parent | 46c59ea9b61755455ff6bf9f3e7b834e1af634ea (diff) | |
| parent | 4df029cc74e5ec124f14a5682e44999ce4f086df (diff) | |
Diffstat (limited to 'contrib/llvm-project/llvm/lib/CodeGen/SelectOptimize.cpp')
-rw-r--r-- | contrib/llvm-project/llvm/lib/CodeGen/SelectOptimize.cpp | 408 |
1 file changed, 288 insertions, 120 deletions
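
The core of this merge is that SelectOptimize no longer tracks plain `SelectInst` groups; it introduces a `SelectLike` wrapper so that `or(zext(i1 C), X)` can be costed and converted exactly like `select(C, X | 1, X)`. As orientation before the full diff, the following is a minimal standalone sketch of that matching step, adapted from the patch's `SelectLike::match`; the helper name `isOrZExtSelectLike` is illustrative and not part of the patch.

```cpp
#include "llvm/IR/Instructions.h"
#include "llvm/IR/PatternMatch.h"

using namespace llvm;
using namespace llvm::PatternMatch;

// Returns true if I is an or(zext(i1 X), Y) (in either operand order),
// which the new SelectLike abstraction treats as select(X, Y | 1, Y).
// The zext must have a single use so it can be folded away when the
// select-like instruction is rewritten into branch form.
static bool isOrZExtSelectLike(Instruction *I) {
  Value *X;
  return match(I, m_c_Or(m_OneUse(m_ZExt(m_Value(X))), m_Value())) &&
         X->getType()->isIntegerTy(1);
}
```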
diff --git a/contrib/llvm-project/llvm/lib/CodeGen/SelectOptimize.cpp b/contrib/llvm-project/llvm/lib/CodeGen/SelectOptimize.cpp index 1316919e65da..9c720864358e 100644 --- a/contrib/llvm-project/llvm/lib/CodeGen/SelectOptimize.cpp +++ b/contrib/llvm-project/llvm/lib/CodeGen/SelectOptimize.cpp @@ -42,6 +42,7 @@ #include <stack> using namespace llvm; +using namespace llvm::PatternMatch; #define DEBUG_TYPE "select-optimize" @@ -114,12 +115,6 @@ public: PreservedAnalyses run(Function &F, FunctionAnalysisManager &FAM); bool runOnFunction(Function &F, Pass &P); -private: - // Select groups consist of consecutive select instructions with the same - // condition. - using SelectGroup = SmallVector<SelectInst *, 2>; - using SelectGroups = SmallVector<SelectGroup, 2>; - using Scaled64 = ScaledNumber<uint64_t>; struct CostInfo { @@ -129,6 +124,145 @@ private: Scaled64 NonPredCost; }; + /// SelectLike is an abstraction over SelectInst and other operations that can + /// act like selects. For example Or(Zext(icmp), X) can be treated like + /// select(icmp, X|1, X). + class SelectLike { + SelectLike(Instruction *I) : I(I) {} + + Instruction *I; + + public: + /// Match a select or select-like instruction, returning a SelectLike. + static SelectLike match(Instruction *I) { + // Select instruction are what we are usually looking for. + if (isa<SelectInst>(I)) + return SelectLike(I); + + // An Or(zext(i1 X), Y) can also be treated like a select, with condition + // C and values Y|1 and Y. + Value *X; + if (PatternMatch::match( + I, m_c_Or(m_OneUse(m_ZExt(m_Value(X))), m_Value())) && + X->getType()->isIntegerTy(1)) + return SelectLike(I); + + return SelectLike(nullptr); + } + + bool isValid() { return I; } + operator bool() { return isValid(); } + + Instruction *getI() { return I; } + const Instruction *getI() const { return I; } + + Type *getType() const { return I->getType(); } + + /// Return the condition for the SelectLike instruction. For example the + /// condition of a select or c in `or(zext(c), x)` + Value *getCondition() const { + if (auto *Sel = dyn_cast<SelectInst>(I)) + return Sel->getCondition(); + // Or(zext) case + if (auto *BO = dyn_cast<BinaryOperator>(I)) { + Value *X; + if (PatternMatch::match(BO->getOperand(0), + m_OneUse(m_ZExt(m_Value(X))))) + return X; + if (PatternMatch::match(BO->getOperand(1), + m_OneUse(m_ZExt(m_Value(X))))) + return X; + } + + llvm_unreachable("Unhandled case in getCondition"); + } + + /// Return the true value for the SelectLike instruction. Note this may not + /// exist for all SelectLike instructions. For example, for `or(zext(c), x)` + /// the true value would be `or(x,1)`. As this value does not exist, nullptr + /// is returned. + Value *getTrueValue() const { + if (auto *Sel = dyn_cast<SelectInst>(I)) + return Sel->getTrueValue(); + // Or(zext) case - The true value is Or(X), so return nullptr as the value + // does not yet exist. + if (isa<BinaryOperator>(I)) + return nullptr; + + llvm_unreachable("Unhandled case in getTrueValue"); + } + + /// Return the false value for the SelectLike instruction. For example the + /// getFalseValue of a select or `x` in `or(zext(c), x)` (which is + /// `select(c, x|1, x)`) + Value *getFalseValue() const { + if (auto *Sel = dyn_cast<SelectInst>(I)) + return Sel->getFalseValue(); + // Or(zext) case - return the operand which is not the zext. 
+ if (auto *BO = dyn_cast<BinaryOperator>(I)) { + Value *X; + if (PatternMatch::match(BO->getOperand(0), + m_OneUse(m_ZExt(m_Value(X))))) + return BO->getOperand(1); + if (PatternMatch::match(BO->getOperand(1), + m_OneUse(m_ZExt(m_Value(X))))) + return BO->getOperand(0); + } + + llvm_unreachable("Unhandled case in getFalseValue"); + } + + /// Return the NonPredCost cost of the true op, given the costs in + /// InstCostMap. This may need to be generated for select-like instructions. + Scaled64 getTrueOpCost(DenseMap<const Instruction *, CostInfo> &InstCostMap, + const TargetTransformInfo *TTI) { + if (auto *Sel = dyn_cast<SelectInst>(I)) + if (auto *I = dyn_cast<Instruction>(Sel->getTrueValue())) + return InstCostMap.contains(I) ? InstCostMap[I].NonPredCost + : Scaled64::getZero(); + + // Or case - add the cost of an extra Or to the cost of the False case. + if (isa<BinaryOperator>(I)) + if (auto I = dyn_cast<Instruction>(getFalseValue())) + if (InstCostMap.contains(I)) { + InstructionCost OrCost = TTI->getArithmeticInstrCost( + Instruction::Or, I->getType(), TargetTransformInfo::TCK_Latency, + {TargetTransformInfo::OK_AnyValue, + TargetTransformInfo::OP_None}, + {TTI::OK_UniformConstantValue, TTI::OP_PowerOf2}); + return InstCostMap[I].NonPredCost + + Scaled64::get(*OrCost.getValue()); + } + + return Scaled64::getZero(); + } + + /// Return the NonPredCost cost of the false op, given the costs in + /// InstCostMap. This may need to be generated for select-like instructions. + Scaled64 + getFalseOpCost(DenseMap<const Instruction *, CostInfo> &InstCostMap, + const TargetTransformInfo *TTI) { + if (auto *Sel = dyn_cast<SelectInst>(I)) + if (auto *I = dyn_cast<Instruction>(Sel->getFalseValue())) + return InstCostMap.contains(I) ? InstCostMap[I].NonPredCost + : Scaled64::getZero(); + + // Or case - return the cost of the false case + if (isa<BinaryOperator>(I)) + if (auto I = dyn_cast<Instruction>(getFalseValue())) + if (InstCostMap.contains(I)) + return InstCostMap[I].NonPredCost; + + return Scaled64::getZero(); + } + }; + +private: + // Select groups consist of consecutive select instructions with the same + // condition. + using SelectGroup = SmallVector<SelectLike, 2>; + using SelectGroups = SmallVector<SelectGroup, 2>; + // Converts select instructions of a function to conditional jumps when deemed // profitable. Returns true if at least one select was converted. bool optimizeSelects(Function &F); @@ -156,12 +290,12 @@ private: // Determines if a select group should be converted to a branch (base // heuristics). - bool isConvertToBranchProfitableBase(const SmallVector<SelectInst *, 2> &ASI); + bool isConvertToBranchProfitableBase(const SelectGroup &ASI); // Returns true if there are expensive instructions in the cold value // operand's (if any) dependence slice of any of the selects of the given // group. - bool hasExpensiveColdOperand(const SmallVector<SelectInst *, 2> &ASI); + bool hasExpensiveColdOperand(const SelectGroup &ASI); // For a given source instruction, collect its backwards dependence slice // consisting of instructions exclusively computed for producing the operands @@ -170,7 +304,7 @@ private: Instruction *SI, bool ForSinking = false); // Returns true if the condition of the select is highly predictable. - bool isSelectHighlyPredictable(const SelectInst *SI); + bool isSelectHighlyPredictable(const SelectLike SI); // Loop-level checks to determine if a non-predicated version (with branches) // of the given loop is more profitable than its predicated version. 
@@ -183,20 +317,21 @@ private: CostInfo *LoopCost); // Returns a set of all the select instructions in the given select groups. - SmallPtrSet<const Instruction *, 2> getSIset(const SelectGroups &SIGroups); + SmallDenseMap<const Instruction *, SelectLike, 2> + getSImap(const SelectGroups &SIGroups); // Returns the latency cost of a given instruction. std::optional<uint64_t> computeInstCost(const Instruction *I); // Returns the misprediction cost of a given select when converted to branch. - Scaled64 getMispredictionCost(const SelectInst *SI, const Scaled64 CondCost); + Scaled64 getMispredictionCost(const SelectLike SI, const Scaled64 CondCost); // Returns the cost of a branch when the prediction is correct. Scaled64 getPredictedPathCost(Scaled64 TrueCost, Scaled64 FalseCost, - const SelectInst *SI); + const SelectLike SI); // Returns true if the target architecture supports lowering a given select. - bool isSelectKindSupported(SelectInst *SI); + bool isSelectKindSupported(const SelectLike SI); }; class SelectOptimize : public FunctionPass { @@ -368,15 +503,26 @@ void SelectOptimizeImpl::optimizeSelectsInnerLoops(Function &F, /// select instructions in \p Selects, look through the defining select /// instruction until the true/false value is not defined in \p Selects. static Value * -getTrueOrFalseValue(SelectInst *SI, bool isTrue, - const SmallPtrSet<const Instruction *, 2> &Selects) { +getTrueOrFalseValue(SelectOptimizeImpl::SelectLike SI, bool isTrue, + const SmallPtrSet<const Instruction *, 2> &Selects, + IRBuilder<> &IB) { Value *V = nullptr; - for (SelectInst *DefSI = SI; DefSI != nullptr && Selects.count(DefSI); + for (SelectInst *DefSI = dyn_cast<SelectInst>(SI.getI()); + DefSI != nullptr && Selects.count(DefSI); DefSI = dyn_cast<SelectInst>(V)) { - assert(DefSI->getCondition() == SI->getCondition() && + assert(DefSI->getCondition() == SI.getCondition() && "The condition of DefSI does not match with SI"); V = (isTrue ? DefSI->getTrueValue() : DefSI->getFalseValue()); } + + if (isa<BinaryOperator>(SI.getI())) { + assert(SI.getI()->getOpcode() == Instruction::Or && + "Only currently handling Or instructions."); + V = SI.getFalseValue(); + if (isTrue) + V = IB.CreateOr(V, ConstantInt::get(V->getType(), 1)); + } + assert(V && "Failed to get select true/false value"); return V; } @@ -424,20 +570,22 @@ void SelectOptimizeImpl::convertProfitableSIGroups(SelectGroups &ProfSIGroups) { SmallVector<std::stack<Instruction *>, 2> TrueSlices, FalseSlices; typedef std::stack<Instruction *>::size_type StackSizeType; StackSizeType maxTrueSliceLen = 0, maxFalseSliceLen = 0; - for (SelectInst *SI : ASI) { + for (SelectLike SI : ASI) { // For each select, compute the sinkable dependence chains of the true and // false operands. 
- if (auto *TI = dyn_cast<Instruction>(SI->getTrueValue())) { + if (auto *TI = dyn_cast_or_null<Instruction>(SI.getTrueValue())) { std::stack<Instruction *> TrueSlice; - getExclBackwardsSlice(TI, TrueSlice, SI, true); + getExclBackwardsSlice(TI, TrueSlice, SI.getI(), true); maxTrueSliceLen = std::max(maxTrueSliceLen, TrueSlice.size()); TrueSlices.push_back(TrueSlice); } - if (auto *FI = dyn_cast<Instruction>(SI->getFalseValue())) { - std::stack<Instruction *> FalseSlice; - getExclBackwardsSlice(FI, FalseSlice, SI, true); - maxFalseSliceLen = std::max(maxFalseSliceLen, FalseSlice.size()); - FalseSlices.push_back(FalseSlice); + if (auto *FI = dyn_cast_or_null<Instruction>(SI.getFalseValue())) { + if (isa<SelectInst>(SI.getI()) || !FI->hasOneUse()) { + std::stack<Instruction *> FalseSlice; + getExclBackwardsSlice(FI, FalseSlice, SI.getI(), true); + maxFalseSliceLen = std::max(maxFalseSliceLen, FalseSlice.size()); + FalseSlices.push_back(FalseSlice); + } } } // In the case of multiple select instructions in the same group, the order @@ -469,10 +617,10 @@ void SelectOptimizeImpl::convertProfitableSIGroups(SelectGroups &ProfSIGroups) { } // We split the block containing the select(s) into two blocks. - SelectInst *SI = ASI.front(); - SelectInst *LastSI = ASI.back(); - BasicBlock *StartBlock = SI->getParent(); - BasicBlock::iterator SplitPt = ++(BasicBlock::iterator(LastSI)); + SelectLike SI = ASI.front(); + SelectLike LastSI = ASI.back(); + BasicBlock *StartBlock = SI.getI()->getParent(); + BasicBlock::iterator SplitPt = ++(BasicBlock::iterator(LastSI.getI())); BasicBlock *EndBlock = StartBlock->splitBasicBlock(SplitPt, "select.end"); BFI->setBlockFreq(EndBlock, BFI->getBlockFreq(StartBlock)); // Delete the unconditional branch that was just created by the split. @@ -481,8 +629,8 @@ void SelectOptimizeImpl::convertProfitableSIGroups(SelectGroups &ProfSIGroups) { // Move any debug/pseudo instructions that were in-between the select // group to the newly-created end block. SmallVector<Instruction *, 2> DebugPseudoINS; - auto DIt = SI->getIterator(); - while (&*DIt != LastSI) { + auto DIt = SI.getI()->getIterator(); + while (&*DIt != LastSI.getI()) { if (DIt->isDebugOrPseudoInst()) DebugPseudoINS.push_back(&*DIt); DIt++; @@ -491,23 +639,41 @@ void SelectOptimizeImpl::convertProfitableSIGroups(SelectGroups &ProfSIGroups) { DI->moveBeforePreserving(&*EndBlock->getFirstInsertionPt()); } + // Duplicate implementation for DPValues, the non-instruction debug-info + // record. Helper lambda for moving DPValues to the end block. + auto TransferDPValues = [&](Instruction &I) { + for (auto &DPValue : llvm::make_early_inc_range(I.getDbgValueRange())) { + DPValue.removeFromParent(); + EndBlock->insertDPValueBefore(&DPValue, + EndBlock->getFirstInsertionPt()); + } + }; + + // Iterate over all instructions in between SI and LastSI, not including + // SI itself. These are all the variable assignments that happen "in the + // middle" of the select group. + auto R = make_range(std::next(SI.getI()->getIterator()), + std::next(LastSI.getI()->getIterator())); + llvm::for_each(R, TransferDPValues); + // These are the new basic blocks for the conditional branch. // At least one will become an actual new basic block. 
BasicBlock *TrueBlock = nullptr, *FalseBlock = nullptr; BranchInst *TrueBranch = nullptr, *FalseBranch = nullptr; if (!TrueSlicesInterleaved.empty()) { - TrueBlock = BasicBlock::Create(LastSI->getContext(), "select.true.sink", + TrueBlock = BasicBlock::Create(EndBlock->getContext(), "select.true.sink", EndBlock->getParent(), EndBlock); TrueBranch = BranchInst::Create(EndBlock, TrueBlock); - TrueBranch->setDebugLoc(LastSI->getDebugLoc()); + TrueBranch->setDebugLoc(LastSI.getI()->getDebugLoc()); for (Instruction *TrueInst : TrueSlicesInterleaved) TrueInst->moveBefore(TrueBranch); } if (!FalseSlicesInterleaved.empty()) { - FalseBlock = BasicBlock::Create(LastSI->getContext(), "select.false.sink", - EndBlock->getParent(), EndBlock); + FalseBlock = + BasicBlock::Create(EndBlock->getContext(), "select.false.sink", + EndBlock->getParent(), EndBlock); FalseBranch = BranchInst::Create(EndBlock, FalseBlock); - FalseBranch->setDebugLoc(LastSI->getDebugLoc()); + FalseBranch->setDebugLoc(LastSI.getI()->getDebugLoc()); for (Instruction *FalseInst : FalseSlicesInterleaved) FalseInst->moveBefore(FalseBranch); } @@ -517,10 +683,10 @@ void SelectOptimizeImpl::convertProfitableSIGroups(SelectGroups &ProfSIGroups) { assert(TrueBlock == nullptr && "Unexpected basic block transform while optimizing select"); - FalseBlock = BasicBlock::Create(SI->getContext(), "select.false", + FalseBlock = BasicBlock::Create(StartBlock->getContext(), "select.false", EndBlock->getParent(), EndBlock); auto *FalseBranch = BranchInst::Create(EndBlock, FalseBlock); - FalseBranch->setDebugLoc(SI->getDebugLoc()); + FalseBranch->setDebugLoc(SI.getI()->getDebugLoc()); } // Insert the real conditional branch based on the original condition. @@ -541,44 +707,36 @@ void SelectOptimizeImpl::convertProfitableSIGroups(SelectGroups &ProfSIGroups) { TT = TrueBlock; FT = FalseBlock; } - IRBuilder<> IB(SI); - auto *CondFr = - IB.CreateFreeze(SI->getCondition(), SI->getName() + ".frozen"); - IB.CreateCondBr(CondFr, TT, FT, SI); + IRBuilder<> IB(SI.getI()); + auto *CondFr = IB.CreateFreeze(SI.getCondition(), + SI.getCondition()->getName() + ".frozen"); SmallPtrSet<const Instruction *, 2> INS; - INS.insert(ASI.begin(), ASI.end()); + for (auto SI : ASI) + INS.insert(SI.getI()); + // Use reverse iterator because later select may use the value of the // earlier select, and we need to propagate value through earlier select // to get the PHI operand. for (auto It = ASI.rbegin(); It != ASI.rend(); ++It) { - SelectInst *SI = *It; + SelectLike SI = *It; // The select itself is replaced with a PHI Node. - PHINode *PN = PHINode::Create(SI->getType(), 2, ""); + PHINode *PN = PHINode::Create(SI.getType(), 2, ""); PN->insertBefore(EndBlock->begin()); - PN->takeName(SI); - PN->addIncoming(getTrueOrFalseValue(SI, true, INS), TrueBlock); - PN->addIncoming(getTrueOrFalseValue(SI, false, INS), FalseBlock); - PN->setDebugLoc(SI->getDebugLoc()); - - SI->replaceAllUsesWith(PN); - SI->eraseFromParent(); - INS.erase(SI); + PN->takeName(SI.getI()); + PN->addIncoming(getTrueOrFalseValue(SI, true, INS, IB), TrueBlock); + PN->addIncoming(getTrueOrFalseValue(SI, false, INS, IB), FalseBlock); + PN->setDebugLoc(SI.getI()->getDebugLoc()); + SI.getI()->replaceAllUsesWith(PN); + INS.erase(SI.getI()); ++NumSelectsConverted; } - } -} + IB.CreateCondBr(CondFr, TT, FT, SI.getI()); -static bool isSpecialSelect(SelectInst *SI) { - using namespace llvm::PatternMatch; - - // If the select is a logical-and/logical-or then it is better treated as a - // and/or by the backend. 
- if (match(SI, m_CombineOr(m_LogicalAnd(m_Value(), m_Value()), - m_LogicalOr(m_Value(), m_Value())))) - return true; - - return false; + // Remove the old select instructions, now that they are not longer used. + for (auto SI : ASI) + SI.getI()->eraseFromParent(); + } } void SelectOptimizeImpl::collectSelectGroups(BasicBlock &BB, @@ -586,22 +744,30 @@ void SelectOptimizeImpl::collectSelectGroups(BasicBlock &BB, BasicBlock::iterator BBIt = BB.begin(); while (BBIt != BB.end()) { Instruction *I = &*BBIt++; - if (SelectInst *SI = dyn_cast<SelectInst>(I)) { - if (isSpecialSelect(SI)) + if (SelectLike SI = SelectLike::match(I)) { + if (!TTI->shouldTreatInstructionLikeSelect(I)) continue; SelectGroup SIGroup; SIGroup.push_back(SI); while (BBIt != BB.end()) { Instruction *NI = &*BBIt; - SelectInst *NSI = dyn_cast<SelectInst>(NI); - if (NSI && SI->getCondition() == NSI->getCondition()) { + // Debug/pseudo instructions should be skipped and not prevent the + // formation of a select group. + if (NI->isDebugOrPseudoInst()) { + ++BBIt; + continue; + } + // We only allow selects in the same group, not other select-like + // instructions. + if (!isa<SelectInst>(NI)) + break; + + SelectLike NSI = SelectLike::match(NI); + if (NSI && SI.getCondition() == NSI.getCondition()) { SIGroup.push_back(NSI); - } else if (!NI->isDebugOrPseudoInst()) { - // Debug/pseudo instructions should be skipped and not prevent the - // formation of a select group. + } else break; - } ++BBIt; } @@ -655,12 +821,12 @@ void SelectOptimizeImpl::findProfitableSIGroupsInnerLoops( // Assuming infinite resources, the cost of a group of instructions is the // cost of the most expensive instruction of the group. Scaled64 SelectCost = Scaled64::getZero(), BranchCost = Scaled64::getZero(); - for (SelectInst *SI : ASI) { - SelectCost = std::max(SelectCost, InstCostMap[SI].PredCost); - BranchCost = std::max(BranchCost, InstCostMap[SI].NonPredCost); + for (SelectLike SI : ASI) { + SelectCost = std::max(SelectCost, InstCostMap[SI.getI()].PredCost); + BranchCost = std::max(BranchCost, InstCostMap[SI.getI()].NonPredCost); } if (BranchCost < SelectCost) { - OptimizationRemark OR(DEBUG_TYPE, "SelectOpti", ASI.front()); + OptimizationRemark OR(DEBUG_TYPE, "SelectOpti", ASI.front().getI()); OR << "Profitable to convert to branch (loop analysis). BranchCost=" << BranchCost.toString() << ", SelectCost=" << SelectCost.toString() << ". "; @@ -668,7 +834,8 @@ void SelectOptimizeImpl::findProfitableSIGroupsInnerLoops( ++NumSelectConvertedLoop; ProfSIGroups.push_back(ASI); } else { - OptimizationRemarkMissed ORmiss(DEBUG_TYPE, "SelectOpti", ASI.front()); + OptimizationRemarkMissed ORmiss(DEBUG_TYPE, "SelectOpti", + ASI.front().getI()); ORmiss << "Select is more profitable (loop analysis). BranchCost=" << BranchCost.toString() << ", SelectCost=" << SelectCost.toString() << ". 
"; @@ -678,14 +845,15 @@ void SelectOptimizeImpl::findProfitableSIGroupsInnerLoops( } bool SelectOptimizeImpl::isConvertToBranchProfitableBase( - const SmallVector<SelectInst *, 2> &ASI) { - SelectInst *SI = ASI.front(); - LLVM_DEBUG(dbgs() << "Analyzing select group containing " << *SI << "\n"); - OptimizationRemark OR(DEBUG_TYPE, "SelectOpti", SI); - OptimizationRemarkMissed ORmiss(DEBUG_TYPE, "SelectOpti", SI); + const SelectGroup &ASI) { + SelectLike SI = ASI.front(); + LLVM_DEBUG(dbgs() << "Analyzing select group containing " << SI.getI() + << "\n"); + OptimizationRemark OR(DEBUG_TYPE, "SelectOpti", SI.getI()); + OptimizationRemarkMissed ORmiss(DEBUG_TYPE, "SelectOpti", SI.getI()); // Skip cold basic blocks. Better to optimize for size for cold blocks. - if (PSI->isColdBlock(SI->getParent(), BFI)) { + if (PSI->isColdBlock(SI.getI()->getParent(), BFI)) { ++NumSelectColdBB; ORmiss << "Not converted to branch because of cold basic block. "; EmitAndPrintRemark(ORE, ORmiss); @@ -693,7 +861,7 @@ bool SelectOptimizeImpl::isConvertToBranchProfitableBase( } // If unpredictable, branch form is less profitable. - if (SI->getMetadata(LLVMContext::MD_unpredictable)) { + if (SI.getI()->getMetadata(LLVMContext::MD_unpredictable)) { ++NumSelectUnPred; ORmiss << "Not converted to branch because of unpredictable branch. "; EmitAndPrintRemark(ORE, ORmiss); @@ -728,17 +896,24 @@ static InstructionCost divideNearest(InstructionCost Numerator, return (Numerator + (Denominator / 2)) / Denominator; } -bool SelectOptimizeImpl::hasExpensiveColdOperand( - const SmallVector<SelectInst *, 2> &ASI) { +static bool extractBranchWeights(const SelectOptimizeImpl::SelectLike SI, + uint64_t &TrueVal, uint64_t &FalseVal) { + if (isa<SelectInst>(SI.getI())) + return extractBranchWeights(*SI.getI(), TrueVal, FalseVal); + return false; +} + +bool SelectOptimizeImpl::hasExpensiveColdOperand(const SelectGroup &ASI) { bool ColdOperand = false; uint64_t TrueWeight, FalseWeight, TotalWeight; - if (extractBranchWeights(*ASI.front(), TrueWeight, FalseWeight)) { + if (extractBranchWeights(ASI.front(), TrueWeight, FalseWeight)) { uint64_t MinWeight = std::min(TrueWeight, FalseWeight); TotalWeight = TrueWeight + FalseWeight; // Is there a path with frequency <ColdOperandThreshold% (default:20%) ? ColdOperand = TotalWeight * ColdOperandThreshold > 100 * MinWeight; } else if (PSI->hasProfileSummary()) { - OptimizationRemarkMissed ORmiss(DEBUG_TYPE, "SelectOpti", ASI.front()); + OptimizationRemarkMissed ORmiss(DEBUG_TYPE, "SelectOpti", + ASI.front().getI()); ORmiss << "Profile data available but missing branch-weights metadata for " "select instruction. "; EmitAndPrintRemark(ORE, ORmiss); @@ -747,19 +922,19 @@ bool SelectOptimizeImpl::hasExpensiveColdOperand( return false; // Check if the cold path's dependence slice is expensive for any of the // selects of the group. 
- for (SelectInst *SI : ASI) { + for (SelectLike SI : ASI) { Instruction *ColdI = nullptr; uint64_t HotWeight; if (TrueWeight < FalseWeight) { - ColdI = dyn_cast<Instruction>(SI->getTrueValue()); + ColdI = dyn_cast_or_null<Instruction>(SI.getTrueValue()); HotWeight = FalseWeight; } else { - ColdI = dyn_cast<Instruction>(SI->getFalseValue()); + ColdI = dyn_cast_or_null<Instruction>(SI.getFalseValue()); HotWeight = TrueWeight; } if (ColdI) { std::stack<Instruction *> ColdSlice; - getExclBackwardsSlice(ColdI, ColdSlice, SI); + getExclBackwardsSlice(ColdI, ColdSlice, SI.getI()); InstructionCost SliceCost = 0; while (!ColdSlice.empty()) { SliceCost += TTI->getInstructionCost(ColdSlice.top(), @@ -849,9 +1024,9 @@ void SelectOptimizeImpl::getExclBackwardsSlice(Instruction *I, } } -bool SelectOptimizeImpl::isSelectHighlyPredictable(const SelectInst *SI) { +bool SelectOptimizeImpl::isSelectHighlyPredictable(const SelectLike SI) { uint64_t TrueWeight, FalseWeight; - if (extractBranchWeights(*SI, TrueWeight, FalseWeight)) { + if (extractBranchWeights(SI, TrueWeight, FalseWeight)) { uint64_t Max = std::max(TrueWeight, FalseWeight); uint64_t Sum = TrueWeight + FalseWeight; if (Sum != 0) { @@ -937,7 +1112,7 @@ bool SelectOptimizeImpl::computeLoopCosts( DenseMap<const Instruction *, CostInfo> &InstCostMap, CostInfo *LoopCost) { LLVM_DEBUG(dbgs() << "Calculating Latency / IPredCost / INonPredCost of loop " << L->getHeader()->getName() << "\n"); - const auto &SIset = getSIset(SIGroups); + const auto &SImap = getSImap(SIGroups); // Compute instruction and loop-critical-path costs across two iterations for // both predicated and non-predicated version. const unsigned Iterations = 2; @@ -982,22 +1157,15 @@ bool SelectOptimizeImpl::computeLoopCosts( // BranchCost = PredictedPathCost + MispredictCost // PredictedPathCost = TrueOpCost * TrueProb + FalseOpCost * FalseProb // MispredictCost = max(MispredictPenalty, CondCost) * MispredictRate - if (SIset.contains(&I)) { - auto SI = cast<SelectInst>(&I); - - Scaled64 TrueOpCost = Scaled64::getZero(), - FalseOpCost = Scaled64::getZero(); - if (auto *TI = dyn_cast<Instruction>(SI->getTrueValue())) - if (InstCostMap.count(TI)) - TrueOpCost = InstCostMap[TI].NonPredCost; - if (auto *FI = dyn_cast<Instruction>(SI->getFalseValue())) - if (InstCostMap.count(FI)) - FalseOpCost = InstCostMap[FI].NonPredCost; + if (SImap.contains(&I)) { + auto SI = SImap.at(&I); + Scaled64 TrueOpCost = SI.getTrueOpCost(InstCostMap, TTI); + Scaled64 FalseOpCost = SI.getFalseOpCost(InstCostMap, TTI); Scaled64 PredictedPathCost = getPredictedPathCost(TrueOpCost, FalseOpCost, SI); Scaled64 CondCost = Scaled64::getZero(); - if (auto *CI = dyn_cast<Instruction>(SI->getCondition())) + if (auto *CI = dyn_cast<Instruction>(SI.getCondition())) if (InstCostMap.count(CI)) CondCost = InstCostMap[CI].NonPredCost; Scaled64 MispredictCost = getMispredictionCost(SI, CondCost); @@ -1019,13 +1187,13 @@ bool SelectOptimizeImpl::computeLoopCosts( return true; } -SmallPtrSet<const Instruction *, 2> -SelectOptimizeImpl::getSIset(const SelectGroups &SIGroups) { - SmallPtrSet<const Instruction *, 2> SIset; +SmallDenseMap<const Instruction *, SelectOptimizeImpl::SelectLike, 2> +SelectOptimizeImpl::getSImap(const SelectGroups &SIGroups) { + SmallDenseMap<const Instruction *, SelectLike, 2> SImap; for (const SelectGroup &ASI : SIGroups) - for (const SelectInst *SI : ASI) - SIset.insert(SI); - return SIset; + for (SelectLike SI : ASI) + SImap.try_emplace(SI.getI(), SI); + return SImap; } std::optional<uint64_t> @@ 
-1038,7 +1206,7 @@ SelectOptimizeImpl::computeInstCost(const Instruction *I) { } ScaledNumber<uint64_t> -SelectOptimizeImpl::getMispredictionCost(const SelectInst *SI, +SelectOptimizeImpl::getMispredictionCost(const SelectLike SI, const Scaled64 CondCost) { uint64_t MispredictPenalty = TSchedModel.getMCSchedModel()->MispredictPenalty; @@ -1065,10 +1233,10 @@ SelectOptimizeImpl::getMispredictionCost(const SelectInst *SI, // TrueCost * TrueProbability + FalseCost * FalseProbability. ScaledNumber<uint64_t> SelectOptimizeImpl::getPredictedPathCost(Scaled64 TrueCost, Scaled64 FalseCost, - const SelectInst *SI) { + const SelectLike SI) { Scaled64 PredPathCost; uint64_t TrueWeight, FalseWeight; - if (extractBranchWeights(*SI, TrueWeight, FalseWeight)) { + if (extractBranchWeights(SI, TrueWeight, FalseWeight)) { uint64_t SumWeight = TrueWeight + FalseWeight; if (SumWeight != 0) { PredPathCost = TrueCost * Scaled64::get(TrueWeight) + @@ -1085,12 +1253,12 @@ SelectOptimizeImpl::getPredictedPathCost(Scaled64 TrueCost, Scaled64 FalseCost, return PredPathCost; } -bool SelectOptimizeImpl::isSelectKindSupported(SelectInst *SI) { - bool VectorCond = !SI->getCondition()->getType()->isIntegerTy(1); +bool SelectOptimizeImpl::isSelectKindSupported(const SelectLike SI) { + bool VectorCond = !SI.getCondition()->getType()->isIntegerTy(1); if (VectorCond) return false; TargetLowering::SelectSupportKind SelectKind; - if (SI->getType()->isVectorTy()) + if (SI.getType()->isVectorTy()) SelectKind = TargetLowering::ScalarCondVectorVal; else SelectKind = TargetLowering::ScalarValSelect; |
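
For reference, the per-select branch cost quoted in the diff's comments (`BranchCost = PredictedPathCost + MispredictCost`, with `PredictedPathCost = TrueOpCost * TrueProb + FalseOpCost * FalseProb` and `MispredictCost = max(MispredictPenalty, CondCost) * MispredictRate`) can be sketched as below. This is a simplified illustration with made-up names; the pass itself works in `ScaledNumber<uint64_t>` arithmetic and draws these inputs from `InstCostMap`, branch-weight metadata, and the target scheduling model.

```cpp
#include <algorithm>

struct BranchCostInputs {
  double TrueOpCost, FalseOpCost;   // non-predicated costs of the two values
  double TrueProb, FalseProb;       // derived from branch-weight metadata
  double CondCost;                  // latency of computing the condition
  double MispredictPenalty;         // from the target's MC scheduling model
  double MispredictRate;            // e.g. 0.25 for a 25% mispredict rate
};

// BranchCost = PredictedPathCost + MispredictCost
// PredictedPathCost = TrueOpCost * TrueProb + FalseOpCost * FalseProb
// MispredictCost = max(MispredictPenalty, CondCost) * MispredictRate
static double branchCost(const BranchCostInputs &In) {
  double PredictedPathCost =
      In.TrueOpCost * In.TrueProb + In.FalseOpCost * In.FalseProb;
  double MispredictCost =
      std::max(In.MispredictPenalty, In.CondCost) * In.MispredictRate;
  return PredictedPathCost + MispredictCost;
}
```

As in `findProfitableSIGroupsInnerLoops` above, a select group is only converted when this branch cost, maximized over the group, comes out lower than the corresponding predicated (select) cost.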