diff options
Diffstat (limited to 'llvm/lib/IR/ProfDataUtils.cpp')
| -rw-r--r-- | llvm/lib/IR/ProfDataUtils.cpp | 48 |
1 files changed, 27 insertions, 21 deletions
diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp index e534368b05e4..29536b0b090c 100644 --- a/llvm/lib/IR/ProfDataUtils.cpp +++ b/llvm/lib/IR/ProfDataUtils.cpp @@ -17,6 +17,7 @@ #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/MDBuilder.h" #include "llvm/IR/Metadata.h" #include "llvm/Support/BranchProbability.h" #include "llvm/Support/CommandLine.h" @@ -45,26 +46,6 @@ constexpr unsigned WeightsIdx = 1; // the minimum number of operands for MD_prof nodes with branch weights constexpr unsigned MinBWOps = 3; -bool extractWeights(const MDNode *ProfileData, - SmallVectorImpl<uint32_t> &Weights) { - // Assume preconditions are already met (i.e. this is valid metadata) - assert(ProfileData && "ProfileData was nullptr in extractWeights"); - unsigned NOps = ProfileData->getNumOperands(); - - assert(WeightsIdx < NOps && "Weights Index must be less than NOps."); - Weights.resize(NOps - WeightsIdx); - - for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) { - ConstantInt *Weight = - mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx)); - assert(Weight && "Malformed branch_weight in MD_prof node"); - assert(Weight->getValue().getActiveBits() <= 32 && - "Too many bits for uint32_t"); - Weights[Idx - WeightsIdx] = Weight->getZExtValue(); - } - return true; -} - // We may want to add support for other MD_prof types, so provide an abstraction // for checking the metadata type. bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) { @@ -119,11 +100,30 @@ MDNode *getValidBranchWeightMDNode(const Instruction &I) { return nullptr; } +void extractFromBranchWeightMD(const MDNode *ProfileData, + SmallVectorImpl<uint32_t> &Weights) { + assert(isBranchWeightMD(ProfileData) && "wrong metadata"); + + unsigned NOps = ProfileData->getNumOperands(); + assert(WeightsIdx < NOps && "Weights Index must be less than NOps."); + Weights.resize(NOps - WeightsIdx); + + for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) { + ConstantInt *Weight = + mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx)); + assert(Weight && "Malformed branch_weight in MD_prof node"); + assert(Weight->getValue().getActiveBits() <= 32 && + "Too many bits for uint32_t"); + Weights[Idx - WeightsIdx] = Weight->getZExtValue(); + } +} + bool extractBranchWeights(const MDNode *ProfileData, SmallVectorImpl<uint32_t> &Weights) { if (!isBranchWeightMD(ProfileData)) return false; - return extractWeights(ProfileData, Weights); + extractFromBranchWeightMD(ProfileData, Weights); + return true; } bool extractBranchWeights(const Instruction &I, @@ -184,4 +184,10 @@ bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) { return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal); } +void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights) { + MDBuilder MDB(I.getContext()); + MDNode *BranchWeights = MDB.createBranchWeights(Weights); + I.setMetadata(LLVMContext::MD_prof, BranchWeights); +} + } // namespace llvm |
