aboutsummaryrefslogtreecommitdiff
path: root/contrib/llvm-project/llvm/lib/Analysis/ModelUnderTrainingRunner.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Analysis/ModelUnderTrainingRunner.cpp')
-rw-r--r-- contrib/llvm-project/llvm/lib/Analysis/ModelUnderTrainingRunner.cpp | 131
1 file changed, 105 insertions(+), 26 deletions(-)
diff --git a/contrib/llvm-project/llvm/lib/Analysis/ModelUnderTrainingRunner.cpp b/contrib/llvm-project/llvm/lib/Analysis/ModelUnderTrainingRunner.cpp
index dc149f326271..3892b09e5b63 100644
--- a/contrib/llvm-project/llvm/lib/Analysis/ModelUnderTrainingRunner.cpp
+++ b/contrib/llvm-project/llvm/lib/Analysis/ModelUnderTrainingRunner.cpp
@@ -11,22 +11,94 @@
//
//===----------------------------------------------------------------------===//
+#include "llvm/ADT/STLExtras.h"
#include "llvm/Config/config.h"
-#if defined(LLVM_HAVE_TF_API)
-
+#if defined(LLVM_HAVE_TFLITE)
#include "llvm/Analysis/ModelUnderTrainingRunner.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/Path.h"
+#include <optional>
using namespace llvm;
+namespace {
+struct LoggedFeatureSpec {
+ TensorSpec Spec;
+ std::optional<std::string> LoggingName;
+};
+
+std::optional<std::vector<LoggedFeatureSpec>>
+loadOutputSpecs(LLVMContext &Ctx, StringRef ExpectedDecisionName,
+ StringRef ModelPath, StringRef SpecFileOverride) {
+ SmallVector<char, 128> OutputSpecsPath;
+ StringRef FileName = SpecFileOverride;
+ if (FileName.empty()) {
+ llvm::sys::path::append(OutputSpecsPath, ModelPath, "output_spec.json");
+ FileName = {OutputSpecsPath.data(), OutputSpecsPath.size()};
+ }
+
+ auto BufferOrError = MemoryBuffer::getFileOrSTDIN(FileName);
+ if (!BufferOrError) {
+ Ctx.emitError("Error opening output specs file: " + FileName + " : " +
+ BufferOrError.getError().message());
+ return std::nullopt;
+ }
+ auto ParsedJSONValues = json::parse(BufferOrError.get()->getBuffer());
+ if (!ParsedJSONValues) {
+ Ctx.emitError("Could not parse specs file: " + FileName);
+ return std::nullopt;
+ }
+ auto ValuesArray = ParsedJSONValues->getAsArray();
+ if (!ValuesArray) {
+ Ctx.emitError("Expected an array of {tensor_spec:<TensorSpec>, "
+ "logging_name:<name>} dictionaries");
+ return std::nullopt;
+ }
+ std::vector<LoggedFeatureSpec> Ret;
+ for (const auto &Value : *ValuesArray)
+ if (const auto *Obj = Value.getAsObject())
+ if (const auto *SpecPart = Obj->get("tensor_spec"))
+ if (auto TensorSpec = getTensorSpecFromJSON(Ctx, *SpecPart))
+ if (auto LoggingName = Obj->getString("logging_name")) {
+ if (!TensorSpec->isElementType<int64_t>() &&
+ !TensorSpec->isElementType<int32_t>() &&
+ !TensorSpec->isElementType<float>()) {
+ Ctx.emitError(
+ "Only int64, int32, and float tensors are supported. "
+ "Found unsupported type for tensor named " +
+ TensorSpec->name());
+ return std::nullopt;
+ }
+ Ret.push_back({*TensorSpec, LoggingName->str()});
+ }
+
+ if (ValuesArray->size() != Ret.size()) {
+ Ctx.emitError(
+ "Unable to parse output spec. It should be a json file containing an "
+ "array of dictionaries. Each dictionary must have a 'tensor_spec' key, "
+ "with a json object describing a TensorSpec; and a 'logging_name' key, "
+ "which is a string to use as name when logging this tensor in the "
+ "training log.");
+ return std::nullopt;
+ }
+ if (Ret.empty() || *Ret[0].LoggingName != ExpectedDecisionName) {
+ Ctx.emitError("The first output spec must describe the decision tensor, "
+ "and must have the logging_name " +
+ StringRef(ExpectedDecisionName));
+ return std::nullopt;
+ }
+ return Ret;
+}
+} // namespace
ModelUnderTrainingRunner::ModelUnderTrainingRunner(
LLVMContext &Ctx, const std::string &ModelPath,
const std::vector<TensorSpec> &InputSpecs,
- const std::vector<LoggedFeatureSpec> &OutputSpecs)
+ const std::vector<TensorSpec> &OutputSpecs,
+ const std::vector<TensorSpec> &ExtraOutputsForLogging)
: MLModelRunner(Ctx, MLModelRunner::Kind::Development, InputSpecs.size()),
- OutputSpecs(OutputSpecs) {
- Evaluator = std::make_unique<TFModelEvaluator>(
- ModelPath, InputSpecs, [&](size_t I) { return OutputSpecs[I].Spec; },
- OutputSpecs.size());
+ OutputSpecs(OutputSpecs), ExtraOutputsForLogging(ExtraOutputsForLogging) {
+ Evaluator =
+ std::make_unique<TFModelEvaluator>(ModelPath, InputSpecs, OutputSpecs);
if (!Evaluator || !Evaluator->isValid()) {
Ctx.emitError("Failed to create saved model evaluator");
Evaluator.reset();
@@ -40,7 +112,7 @@ ModelUnderTrainingRunner::ModelUnderTrainingRunner(
void *ModelUnderTrainingRunner::evaluateUntyped() {
LastEvaluationResult = Evaluator->evaluate();
- if (!LastEvaluationResult.hasValue()) {
+ if (!LastEvaluationResult.has_value()) {
Ctx.emitError("Error evaluating model.");
return nullptr;
}
@@ -53,26 +125,33 @@ ModelUnderTrainingRunner::createAndEnsureValid(
const std::vector<TensorSpec> &InputSpecs,
StringRef OutputSpecsPathOverride) {
if (auto MaybeOutputSpecs = loadOutputSpecs(Ctx, DecisionName, ModelPath,
- OutputSpecsPathOverride))
- return createAndEnsureValid(Ctx, ModelPath, DecisionName, InputSpecs,
- *MaybeOutputSpecs);
- Ctx.emitError("Could not load the policy model from the provided path");
- return nullptr;
-}
+ OutputSpecsPathOverride)) {
+ std::unique_ptr<ModelUnderTrainingRunner> MUTR;
+ std::vector<TensorSpec> OutputSpecs;
+ std::vector<TensorSpec> ExtraOutputsForLogging;
+ append_range(OutputSpecs,
+ map_range(*MaybeOutputSpecs, [](const LoggedFeatureSpec &LFS) {
+ return LFS.Spec;
+ }));
+ append_range(ExtraOutputsForLogging,
+ map_range(drop_begin(*MaybeOutputSpecs),
+ [](const LoggedFeatureSpec &LFS) {
+ return TensorSpec(LFS.LoggingName
+ ? *LFS.LoggingName
+ : LFS.Spec.name(),
+ LFS.Spec);
+ }));
-std::unique_ptr<ModelUnderTrainingRunner>
-ModelUnderTrainingRunner::createAndEnsureValid(
- LLVMContext &Ctx, const std::string &ModelPath, StringRef DecisionName,
- const std::vector<TensorSpec> &InputSpecs,
- const std::vector<LoggedFeatureSpec> &OutputSpecs) {
- std::unique_ptr<ModelUnderTrainingRunner> MUTR;
- MUTR.reset(
- new ModelUnderTrainingRunner(Ctx, ModelPath, InputSpecs, OutputSpecs));
- if (MUTR && MUTR->isValid())
- return MUTR;
+ MUTR.reset(new ModelUnderTrainingRunner(
+ Ctx, ModelPath, InputSpecs, OutputSpecs, ExtraOutputsForLogging));
+ if (MUTR && MUTR->isValid())
+ return MUTR;
- Ctx.emitError("Could not load or create model evaluator.");
+ Ctx.emitError("Could not load or create model evaluator.");
+ return nullptr;
+ }
+ Ctx.emitError("Could not load the policy model from the provided path");
return nullptr;
}
-#endif // defined(LLVM_HAVE_TF_API)
+#endif // defined(LLVM_HAVE_TFLITE)