diff options
author | Dimitry Andric <dim@FreeBSD.org> | 2024-07-27 23:34:35 +0000 |
---|---|---|
committer | Dimitry Andric <dim@FreeBSD.org> | 2024-10-23 18:26:01 +0000 |
commit | 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583 (patch) | |
tree | 6cf5ab1f05330c6773b1f3f64799d56a9c7a1faa /contrib/llvm-project/llvm/lib/IR/ProfDataUtils.cpp | |
parent | 6b9f7133aba44189d9625c352bc2c2a59baf18ef (diff) | |
parent | ac9a064cb179f3425b310fa2847f8764ac970a4d (diff) |
Diffstat (limited to 'contrib/llvm-project/llvm/lib/IR/ProfDataUtils.cpp')
-rw-r--r-- | contrib/llvm-project/llvm/lib/IR/ProfDataUtils.cpp | 160 |
1 files changed, 133 insertions, 27 deletions
diff --git a/contrib/llvm-project/llvm/lib/IR/ProfDataUtils.cpp b/contrib/llvm-project/llvm/lib/IR/ProfDataUtils.cpp index 29536b0b090c..992ce34e0003 100644 --- a/contrib/llvm-project/llvm/lib/IR/ProfDataUtils.cpp +++ b/contrib/llvm-project/llvm/lib/IR/ProfDataUtils.cpp @@ -19,6 +19,7 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Metadata.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/Support/BranchProbability.h" #include "llvm/Support/CommandLine.h" @@ -40,12 +41,12 @@ namespace { // We maintain some constants here to ensure that we access the branch weights // correctly, and can change the behavior in the future if the layout changes -// The index at which the weights vector starts -constexpr unsigned WeightsIdx = 1; - // the minimum number of operands for MD_prof nodes with branch weights constexpr unsigned MinBWOps = 3; +// the minimum number of operands for MD_prof nodes with value profiles +constexpr unsigned MinVPOps = 5; + // We may want to add support for other MD_prof types, so provide an abstraction // for checking the metadata type. bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) { @@ -62,7 +63,28 @@ bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) { if (!ProfDataName) return false; - return ProfDataName->getString().equals(Name); + return ProfDataName->getString() == Name; +} + +template <typename T, + typename = typename std::enable_if<std::is_arithmetic_v<T>>> +static void extractFromBranchWeightMD(const MDNode *ProfileData, + SmallVectorImpl<T> &Weights) { + assert(isBranchWeightMD(ProfileData) && "wrong metadata"); + + unsigned NOps = ProfileData->getNumOperands(); + unsigned WeightsIdx = getBranchWeightOffset(ProfileData); + assert(WeightsIdx < NOps && "Weights Index must be less than NOps."); + Weights.resize(NOps - WeightsIdx); + + for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) { + ConstantInt *Weight = + mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx)); + assert(Weight && "Malformed branch_weight in MD_prof node"); + assert(Weight->getValue().getActiveBits() <= (sizeof(T) * 8) && + "Too many bits for MD_prof branch_weight"); + Weights[Idx - WeightsIdx] = Weight->getZExtValue(); + } } } // namespace @@ -70,22 +92,60 @@ bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) { namespace llvm { bool hasProfMD(const Instruction &I) { - return nullptr != I.getMetadata(LLVMContext::MD_prof); + return I.hasMetadata(LLVMContext::MD_prof); } bool isBranchWeightMD(const MDNode *ProfileData) { return isTargetMD(ProfileData, "branch_weights", MinBWOps); } +bool isValueProfileMD(const MDNode *ProfileData) { + return isTargetMD(ProfileData, "VP", MinVPOps); +} + bool hasBranchWeightMD(const Instruction &I) { auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); return isBranchWeightMD(ProfileData); } +bool hasCountTypeMD(const Instruction &I) { + auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); + // Value profiles record count-type information. + if (isValueProfileMD(ProfileData)) + return true; + // Conservatively assume non CallBase instruction only get taken/not-taken + // branch probability, so not interpret them as count. + return isa<CallBase>(I) && !isBranchWeightMD(ProfileData); +} + bool hasValidBranchWeightMD(const Instruction &I) { return getValidBranchWeightMDNode(I); } +bool hasBranchWeightOrigin(const Instruction &I) { + auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); + return hasBranchWeightOrigin(ProfileData); +} + +bool hasBranchWeightOrigin(const MDNode *ProfileData) { + if (!isBranchWeightMD(ProfileData)) + return false; + auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(1)); + // NOTE: if we ever have more types of branch weight provenance, + // we need to check the string value is "expected". For now, we + // supply a more generic API, and avoid the spurious comparisons. + assert(ProfDataName == nullptr || ProfDataName->getString() == "expected"); + return ProfDataName != nullptr; +} + +unsigned getBranchWeightOffset(const MDNode *ProfileData) { + return hasBranchWeightOrigin(ProfileData) ? 2 : 1; +} + +unsigned getNumBranchWeights(const MDNode &ProfileData) { + return ProfileData.getNumOperands() - getBranchWeightOffset(&ProfileData); +} + MDNode *getBranchWeightMDNode(const Instruction &I) { auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); if (!isBranchWeightMD(ProfileData)) @@ -95,27 +155,19 @@ MDNode *getBranchWeightMDNode(const Instruction &I) { MDNode *getValidBranchWeightMDNode(const Instruction &I) { auto *ProfileData = getBranchWeightMDNode(I); - if (ProfileData && ProfileData->getNumOperands() == 1 + I.getNumSuccessors()) + if (ProfileData && getNumBranchWeights(*ProfileData) == I.getNumSuccessors()) return ProfileData; return nullptr; } -void extractFromBranchWeightMD(const MDNode *ProfileData, - SmallVectorImpl<uint32_t> &Weights) { - assert(isBranchWeightMD(ProfileData) && "wrong metadata"); - - unsigned NOps = ProfileData->getNumOperands(); - assert(WeightsIdx < NOps && "Weights Index must be less than NOps."); - Weights.resize(NOps - WeightsIdx); +void extractFromBranchWeightMD32(const MDNode *ProfileData, + SmallVectorImpl<uint32_t> &Weights) { + extractFromBranchWeightMD(ProfileData, Weights); +} - for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) { - ConstantInt *Weight = - mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx)); - assert(Weight && "Malformed branch_weight in MD_prof node"); - assert(Weight->getValue().getActiveBits() <= 32 && - "Too many bits for uint32_t"); - Weights[Idx - WeightsIdx] = Weight->getZExtValue(); - } +void extractFromBranchWeightMD64(const MDNode *ProfileData, + SmallVectorImpl<uint64_t> &Weights) { + extractFromBranchWeightMD(ProfileData, Weights); } bool extractBranchWeights(const MDNode *ProfileData, @@ -161,8 +213,9 @@ bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) { if (!ProfDataName) return false; - if (ProfDataName->getString().equals("branch_weights")) { - for (unsigned Idx = 1; Idx < ProfileData->getNumOperands(); Idx++) { + if (ProfDataName->getString() == "branch_weights") { + unsigned Offset = getBranchWeightOffset(ProfileData); + for (unsigned Idx = Offset; Idx < ProfileData->getNumOperands(); ++Idx) { auto *V = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx)); assert(V && "Malformed branch_weight in MD_prof node"); TotalVal += V->getValue().getZExtValue(); @@ -170,8 +223,7 @@ bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) { return true; } - if (ProfDataName->getString().equals("VP") && - ProfileData->getNumOperands() > 3) { + if (ProfDataName->getString() == "VP" && ProfileData->getNumOperands() > 3) { TotalVal = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2)) ->getValue() .getZExtValue(); @@ -184,10 +236,64 @@ bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) { return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal); } -void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights) { +void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights, + bool IsExpected) { MDBuilder MDB(I.getContext()); - MDNode *BranchWeights = MDB.createBranchWeights(Weights); + MDNode *BranchWeights = MDB.createBranchWeights(Weights, IsExpected); I.setMetadata(LLVMContext::MD_prof, BranchWeights); } +void scaleProfData(Instruction &I, uint64_t S, uint64_t T) { + assert(T != 0 && "Caller should guarantee"); + auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); + if (ProfileData == nullptr) + return; + + auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0)); + if (!ProfDataName || (ProfDataName->getString() != "branch_weights" && + ProfDataName->getString() != "VP")) + return; + + if (!hasCountTypeMD(I)) + return; + + LLVMContext &C = I.getContext(); + + MDBuilder MDB(C); + SmallVector<Metadata *, 3> Vals; + Vals.push_back(ProfileData->getOperand(0)); + APInt APS(128, S), APT(128, T); + if (ProfDataName->getString() == "branch_weights" && + ProfileData->getNumOperands() > 0) { + // Using APInt::div may be expensive, but most cases should fit 64 bits. + APInt Val(128, + mdconst::dyn_extract<ConstantInt>( + ProfileData->getOperand(getBranchWeightOffset(ProfileData))) + ->getValue() + .getZExtValue()); + Val *= APS; + Vals.push_back(MDB.createConstant(ConstantInt::get( + Type::getInt32Ty(C), Val.udiv(APT).getLimitedValue(UINT32_MAX)))); + } else if (ProfDataName->getString() == "VP") + for (unsigned i = 1; i < ProfileData->getNumOperands(); i += 2) { + // The first value is the key of the value profile, which will not change. + Vals.push_back(ProfileData->getOperand(i)); + uint64_t Count = + mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(i + 1)) + ->getValue() + .getZExtValue(); + // Don't scale the magic number. + if (Count == NOMORE_ICP_MAGICNUM) { + Vals.push_back(ProfileData->getOperand(i + 1)); + continue; + } + // Using APInt::div may be expensive, but most cases should fit 64 bits. + APInt Val(128, Count); + Val *= APS; + Vals.push_back(MDB.createConstant(ConstantInt::get( + Type::getInt64Ty(C), Val.udiv(APT).getLimitedValue()))); + } + I.setMetadata(LLVMContext::MD_prof, MDNode::get(C, Vals)); +} + } // namespace llvm |