aboutsummaryrefslogtreecommitdiff
path: root/contrib/llvm-project/llvm/lib/IR/ProfDataUtils.cpp
diff options
context:
space:
mode:
authorDimitry Andric <dim@FreeBSD.org>2024-07-27 23:34:35 +0000
committerDimitry Andric <dim@FreeBSD.org>2024-10-23 18:26:01 +0000
commit0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583 (patch)
tree6cf5ab1f05330c6773b1f3f64799d56a9c7a1faa /contrib/llvm-project/llvm/lib/IR/ProfDataUtils.cpp
parent6b9f7133aba44189d9625c352bc2c2a59baf18ef (diff)
parentac9a064cb179f3425b310fa2847f8764ac970a4d (diff)
Diffstat (limited to 'contrib/llvm-project/llvm/lib/IR/ProfDataUtils.cpp')
-rw-r--r--contrib/llvm-project/llvm/lib/IR/ProfDataUtils.cpp160
1 files changed, 133 insertions, 27 deletions
diff --git a/contrib/llvm-project/llvm/lib/IR/ProfDataUtils.cpp b/contrib/llvm-project/llvm/lib/IR/ProfDataUtils.cpp
index 29536b0b090c..992ce34e0003 100644
--- a/contrib/llvm-project/llvm/lib/IR/ProfDataUtils.cpp
+++ b/contrib/llvm-project/llvm/lib/IR/ProfDataUtils.cpp
@@ -19,6 +19,7 @@
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"
+#include "llvm/IR/ProfDataUtils.h"
#include "llvm/Support/BranchProbability.h"
#include "llvm/Support/CommandLine.h"
@@ -40,12 +41,12 @@ namespace {
// We maintain some constants here to ensure that we access the branch weights
// correctly, and can change the behavior in the future if the layout changes
-// The index at which the weights vector starts
-constexpr unsigned WeightsIdx = 1;
-
// the minimum number of operands for MD_prof nodes with branch weights
constexpr unsigned MinBWOps = 3;
+// the minimum number of operands for MD_prof nodes with value profiles
+constexpr unsigned MinVPOps = 5;
+
// We may want to add support for other MD_prof types, so provide an abstraction
// for checking the metadata type.
bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) {
@@ -62,7 +63,28 @@ bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) {
if (!ProfDataName)
return false;
- return ProfDataName->getString().equals(Name);
+ return ProfDataName->getString() == Name;
+}
+
+template <typename T,
+ typename = typename std::enable_if<std::is_arithmetic_v<T>>>
+static void extractFromBranchWeightMD(const MDNode *ProfileData,
+ SmallVectorImpl<T> &Weights) {
+ assert(isBranchWeightMD(ProfileData) && "wrong metadata");
+
+ unsigned NOps = ProfileData->getNumOperands();
+ unsigned WeightsIdx = getBranchWeightOffset(ProfileData);
+ assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
+ Weights.resize(NOps - WeightsIdx);
+
+ for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) {
+ ConstantInt *Weight =
+ mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
+ assert(Weight && "Malformed branch_weight in MD_prof node");
+ assert(Weight->getValue().getActiveBits() <= (sizeof(T) * 8) &&
+ "Too many bits for MD_prof branch_weight");
+ Weights[Idx - WeightsIdx] = Weight->getZExtValue();
+ }
}
} // namespace
@@ -70,22 +92,60 @@ bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) {
namespace llvm {
bool hasProfMD(const Instruction &I) {
- return nullptr != I.getMetadata(LLVMContext::MD_prof);
+ return I.hasMetadata(LLVMContext::MD_prof);
}
bool isBranchWeightMD(const MDNode *ProfileData) {
return isTargetMD(ProfileData, "branch_weights", MinBWOps);
}
+bool isValueProfileMD(const MDNode *ProfileData) {
+ return isTargetMD(ProfileData, "VP", MinVPOps);
+}
+
bool hasBranchWeightMD(const Instruction &I) {
auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
return isBranchWeightMD(ProfileData);
}
+bool hasCountTypeMD(const Instruction &I) {
+ auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
+ // Value profiles record count-type information.
+ if (isValueProfileMD(ProfileData))
+ return true;
+ // Conservatively assume non CallBase instruction only get taken/not-taken
+ // branch probability, so not interpret them as count.
+ return isa<CallBase>(I) && !isBranchWeightMD(ProfileData);
+}
+
bool hasValidBranchWeightMD(const Instruction &I) {
return getValidBranchWeightMDNode(I);
}
+bool hasBranchWeightOrigin(const Instruction &I) {
+ auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
+ return hasBranchWeightOrigin(ProfileData);
+}
+
+bool hasBranchWeightOrigin(const MDNode *ProfileData) {
+ if (!isBranchWeightMD(ProfileData))
+ return false;
+ auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(1));
+ // NOTE: if we ever have more types of branch weight provenance,
+ // we need to check the string value is "expected". For now, we
+ // supply a more generic API, and avoid the spurious comparisons.
+ assert(ProfDataName == nullptr || ProfDataName->getString() == "expected");
+ return ProfDataName != nullptr;
+}
+
+unsigned getBranchWeightOffset(const MDNode *ProfileData) {
+ return hasBranchWeightOrigin(ProfileData) ? 2 : 1;
+}
+
+unsigned getNumBranchWeights(const MDNode &ProfileData) {
+ return ProfileData.getNumOperands() - getBranchWeightOffset(&ProfileData);
+}
+
MDNode *getBranchWeightMDNode(const Instruction &I) {
auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
if (!isBranchWeightMD(ProfileData))
@@ -95,27 +155,19 @@ MDNode *getBranchWeightMDNode(const Instruction &I) {
MDNode *getValidBranchWeightMDNode(const Instruction &I) {
auto *ProfileData = getBranchWeightMDNode(I);
- if (ProfileData && ProfileData->getNumOperands() == 1 + I.getNumSuccessors())
+ if (ProfileData && getNumBranchWeights(*ProfileData) == I.getNumSuccessors())
return ProfileData;
return nullptr;
}
-void extractFromBranchWeightMD(const MDNode *ProfileData,
- SmallVectorImpl<uint32_t> &Weights) {
- assert(isBranchWeightMD(ProfileData) && "wrong metadata");
-
- unsigned NOps = ProfileData->getNumOperands();
- assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
- Weights.resize(NOps - WeightsIdx);
+void extractFromBranchWeightMD32(const MDNode *ProfileData,
+ SmallVectorImpl<uint32_t> &Weights) {
+ extractFromBranchWeightMD(ProfileData, Weights);
+}
- for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) {
- ConstantInt *Weight =
- mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
- assert(Weight && "Malformed branch_weight in MD_prof node");
- assert(Weight->getValue().getActiveBits() <= 32 &&
- "Too many bits for uint32_t");
- Weights[Idx - WeightsIdx] = Weight->getZExtValue();
- }
+void extractFromBranchWeightMD64(const MDNode *ProfileData,
+ SmallVectorImpl<uint64_t> &Weights) {
+ extractFromBranchWeightMD(ProfileData, Weights);
}
bool extractBranchWeights(const MDNode *ProfileData,
@@ -161,8 +213,9 @@ bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) {
if (!ProfDataName)
return false;
- if (ProfDataName->getString().equals("branch_weights")) {
- for (unsigned Idx = 1; Idx < ProfileData->getNumOperands(); Idx++) {
+ if (ProfDataName->getString() == "branch_weights") {
+ unsigned Offset = getBranchWeightOffset(ProfileData);
+ for (unsigned Idx = Offset; Idx < ProfileData->getNumOperands(); ++Idx) {
auto *V = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
assert(V && "Malformed branch_weight in MD_prof node");
TotalVal += V->getValue().getZExtValue();
@@ -170,8 +223,7 @@ bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) {
return true;
}
- if (ProfDataName->getString().equals("VP") &&
- ProfileData->getNumOperands() > 3) {
+ if (ProfDataName->getString() == "VP" && ProfileData->getNumOperands() > 3) {
TotalVal = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2))
->getValue()
.getZExtValue();
@@ -184,10 +236,64 @@ bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) {
return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal);
}
-void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights) {
+void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
+ bool IsExpected) {
MDBuilder MDB(I.getContext());
- MDNode *BranchWeights = MDB.createBranchWeights(Weights);
+ MDNode *BranchWeights = MDB.createBranchWeights(Weights, IsExpected);
I.setMetadata(LLVMContext::MD_prof, BranchWeights);
}
+void scaleProfData(Instruction &I, uint64_t S, uint64_t T) {
+ assert(T != 0 && "Caller should guarantee");
+ auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
+ if (ProfileData == nullptr)
+ return;
+
+ auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
+ if (!ProfDataName || (ProfDataName->getString() != "branch_weights" &&
+ ProfDataName->getString() != "VP"))
+ return;
+
+ if (!hasCountTypeMD(I))
+ return;
+
+ LLVMContext &C = I.getContext();
+
+ MDBuilder MDB(C);
+ SmallVector<Metadata *, 3> Vals;
+ Vals.push_back(ProfileData->getOperand(0));
+ APInt APS(128, S), APT(128, T);
+ if (ProfDataName->getString() == "branch_weights" &&
+ ProfileData->getNumOperands() > 0) {
+ // Using APInt::div may be expensive, but most cases should fit 64 bits.
+ APInt Val(128,
+ mdconst::dyn_extract<ConstantInt>(
+ ProfileData->getOperand(getBranchWeightOffset(ProfileData)))
+ ->getValue()
+ .getZExtValue());
+ Val *= APS;
+ Vals.push_back(MDB.createConstant(ConstantInt::get(
+ Type::getInt32Ty(C), Val.udiv(APT).getLimitedValue(UINT32_MAX))));
+ } else if (ProfDataName->getString() == "VP")
+ for (unsigned i = 1; i < ProfileData->getNumOperands(); i += 2) {
+ // The first value is the key of the value profile, which will not change.
+ Vals.push_back(ProfileData->getOperand(i));
+ uint64_t Count =
+ mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(i + 1))
+ ->getValue()
+ .getZExtValue();
+ // Don't scale the magic number.
+ if (Count == NOMORE_ICP_MAGICNUM) {
+ Vals.push_back(ProfileData->getOperand(i + 1));
+ continue;
+ }
+ // Using APInt::div may be expensive, but most cases should fit 64 bits.
+ APInt Val(128, Count);
+ Val *= APS;
+ Vals.push_back(MDB.createConstant(ConstantInt::get(
+ Type::getInt64Ty(C), Val.udiv(APT).getLimitedValue())));
+ }
+ I.setMetadata(LLVMContext::MD_prof, MDNode::get(C, Vals));
+}
+
} // namespace llvm