about summary refs log tree commit diff
path: root/contrib/llvm-project/llvm/lib/Analysis/TrainingLogger.cpp
diff options
context:
space:
mode:
author Dimitry Andric <dim@FreeBSD.org> 2023-09-02 21:17:18 +0000
committer Dimitry Andric <dim@FreeBSD.org> 2023-12-08 17:34:50 +0000
commit 06c3fb2749bda94cb5201f81ffdb8fa6c3161b2e (patch)
tree 62f873df87c7c675557a179e0c4c83fe9f3087bc /contrib/llvm-project/llvm/lib/Analysis/TrainingLogger.cpp
parent cf037972ea8863e2bab7461d77345367d2c1e054 (diff)
parent 7fa27ce4a07f19b07799a767fc29416f3b625afb (diff)
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Analysis/TrainingLogger.cpp')
-rw-r--r-- contrib/llvm-project/llvm/lib/Analysis/TrainingLogger.cpp 12
1 files changed, 9 insertions, 3 deletions
diff --git a/contrib/llvm-project/llvm/lib/Analysis/TrainingLogger.cpp b/contrib/llvm-project/llvm/lib/Analysis/TrainingLogger.cpp
index dcee8d40c53d..e236890aa2bc 100644
--- a/contrib/llvm-project/llvm/lib/Analysis/TrainingLogger.cpp
+++ b/contrib/llvm-project/llvm/lib/Analysis/TrainingLogger.cpp
@@ -32,7 +32,7 @@ static cl::opt<bool>
UseSimpleLogger("tfutils-use-simplelogger", cl::init(true), cl::Hidden,
cl::desc("Output simple (non-protobuf) log."));
-void Logger::writeHeader() {
+void Logger::writeHeader(std::optional<TensorSpec> AdviceSpec) {
json::OStream JOS(*OS);
JOS.object([&]() {
JOS.attributeArray("features", [&]() {
@@ -44,6 +44,11 @@ void Logger::writeHeader() {
RewardSpec.toJSON(JOS);
JOS.attributeEnd();
}
+ if (AdviceSpec.has_value()) {
+ JOS.attributeBegin("advice");
+ AdviceSpec->toJSON(JOS);
+ JOS.attributeEnd();
+ }
});
*OS << "\n";
}
@@ -81,8 +86,9 @@ void Logger::logRewardImpl(const char *RawData) {
Logger::Logger(std::unique_ptr<raw_ostream> OS,
const std::vector<TensorSpec> &FeatureSpecs,
- const TensorSpec &RewardSpec, bool IncludeReward)
+ const TensorSpec &RewardSpec, bool IncludeReward,
+ std::optional<TensorSpec> AdviceSpec)
: OS(std::move(OS)), FeatureSpecs(FeatureSpecs), RewardSpec(RewardSpec),
IncludeReward(IncludeReward) {
- writeHeader();
+ writeHeader(AdviceSpec);
}