diff options
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Analysis/TrainingLogger.cpp')
-rw-r--r-- | contrib/llvm-project/llvm/lib/Analysis/TrainingLogger.cpp | 12 |
1 files changed, 9 insertions, 3 deletions
diff --git a/contrib/llvm-project/llvm/lib/Analysis/TrainingLogger.cpp b/contrib/llvm-project/llvm/lib/Analysis/TrainingLogger.cpp index dcee8d40c53d..e236890aa2bc 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/TrainingLogger.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/TrainingLogger.cpp @@ -32,7 +32,7 @@ static cl::opt<bool> UseSimpleLogger("tfutils-use-simplelogger", cl::init(true), cl::Hidden, cl::desc("Output simple (non-protobuf) log.")); -void Logger::writeHeader() { +void Logger::writeHeader(std::optional<TensorSpec> AdviceSpec) { json::OStream JOS(*OS); JOS.object([&]() { JOS.attributeArray("features", [&]() { @@ -44,6 +44,11 @@ void Logger::writeHeader() { RewardSpec.toJSON(JOS); JOS.attributeEnd(); } + if (AdviceSpec.has_value()) { + JOS.attributeBegin("advice"); + AdviceSpec->toJSON(JOS); + JOS.attributeEnd(); + } }); *OS << "\n"; } @@ -81,8 +86,9 @@ void Logger::logRewardImpl(const char *RawData) { Logger::Logger(std::unique_ptr<raw_ostream> OS, const std::vector<TensorSpec> &FeatureSpecs, - const TensorSpec &RewardSpec, bool IncludeReward) + const TensorSpec &RewardSpec, bool IncludeReward, + std::optional<TensorSpec> AdviceSpec) : OS(std::move(OS)), FeatureSpecs(FeatureSpecs), RewardSpec(RewardSpec), IncludeReward(IncludeReward) { - writeHeader(); + writeHeader(AdviceSpec); } |