diff options
author | Dimitry Andric <dim@FreeBSD.org> | 2023-09-02 21:17:18 +0000 |
---|---|---|
committer | Dimitry Andric <dim@FreeBSD.org> | 2023-12-08 17:34:50 +0000 |
commit | 06c3fb2749bda94cb5201f81ffdb8fa6c3161b2e (patch) | |
tree | 62f873df87c7c675557a179e0c4c83fe9f3087bc /contrib/llvm-project/llvm/lib/Analysis/TrainingLogger.cpp | |
parent | cf037972ea8863e2bab7461d77345367d2c1e054 (diff) | |
parent | 7fa27ce4a07f19b07799a767fc29416f3b625afb (diff) |
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Analysis/TrainingLogger.cpp')
-rw-r--r-- | contrib/llvm-project/llvm/lib/Analysis/TrainingLogger.cpp | 12 |
1 files changed, 9 insertions, 3 deletions
diff --git a/contrib/llvm-project/llvm/lib/Analysis/TrainingLogger.cpp b/contrib/llvm-project/llvm/lib/Analysis/TrainingLogger.cpp index dcee8d40c53d..e236890aa2bc 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/TrainingLogger.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/TrainingLogger.cpp @@ -32,7 +32,7 @@ static cl::opt<bool> UseSimpleLogger("tfutils-use-simplelogger", cl::init(true), cl::Hidden, cl::desc("Output simple (non-protobuf) log.")); -void Logger::writeHeader() { +void Logger::writeHeader(std::optional<TensorSpec> AdviceSpec) { json::OStream JOS(*OS); JOS.object([&]() { JOS.attributeArray("features", [&]() { @@ -44,6 +44,11 @@ void Logger::writeHeader() { RewardSpec.toJSON(JOS); JOS.attributeEnd(); } + if (AdviceSpec.has_value()) { + JOS.attributeBegin("advice"); + AdviceSpec->toJSON(JOS); + JOS.attributeEnd(); + } }); *OS << "\n"; } @@ -81,8 +86,9 @@ void Logger::logRewardImpl(const char *RawData) { Logger::Logger(std::unique_ptr<raw_ostream> OS, const std::vector<TensorSpec> &FeatureSpecs, - const TensorSpec &RewardSpec, bool IncludeReward) + const TensorSpec &RewardSpec, bool IncludeReward, + std::optional<TensorSpec> AdviceSpec) : OS(std::move(OS)), FeatureSpecs(FeatureSpecs), RewardSpec(RewardSpec), IncludeReward(IncludeReward) { - writeHeader(); + writeHeader(AdviceSpec); } |