diff options
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Analysis/ModelUnderTrainingRunner.cpp')
-rw-r--r-- | contrib/llvm-project/llvm/lib/Analysis/ModelUnderTrainingRunner.cpp | 131 |
1 files changed, 105 insertions, 26 deletions
diff --git a/contrib/llvm-project/llvm/lib/Analysis/ModelUnderTrainingRunner.cpp b/contrib/llvm-project/llvm/lib/Analysis/ModelUnderTrainingRunner.cpp index dc149f326271..3892b09e5b63 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/ModelUnderTrainingRunner.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/ModelUnderTrainingRunner.cpp @@ -11,22 +11,94 @@ // //===----------------------------------------------------------------------===// +#include "llvm/ADT/STLExtras.h" #include "llvm/Config/config.h" -#if defined(LLVM_HAVE_TF_API) - +#if defined(LLVM_HAVE_TFLITE) #include "llvm/Analysis/ModelUnderTrainingRunner.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/Path.h" +#include <optional> using namespace llvm; +namespace { +struct LoggedFeatureSpec { + TensorSpec Spec; + std::optional<std::string> LoggingName; +}; + +std::optional<std::vector<LoggedFeatureSpec>> +loadOutputSpecs(LLVMContext &Ctx, StringRef ExpectedDecisionName, + StringRef ModelPath, StringRef SpecFileOverride) { + SmallVector<char, 128> OutputSpecsPath; + StringRef FileName = SpecFileOverride; + if (FileName.empty()) { + llvm::sys::path::append(OutputSpecsPath, ModelPath, "output_spec.json"); + FileName = {OutputSpecsPath.data(), OutputSpecsPath.size()}; + } + + auto BufferOrError = MemoryBuffer::getFileOrSTDIN(FileName); + if (!BufferOrError) { + Ctx.emitError("Error opening output specs file: " + FileName + " : " + + BufferOrError.getError().message()); + return std::nullopt; + } + auto ParsedJSONValues = json::parse(BufferOrError.get()->getBuffer()); + if (!ParsedJSONValues) { + Ctx.emitError("Could not parse specs file: " + FileName); + return std::nullopt; + } + auto ValuesArray = ParsedJSONValues->getAsArray(); + if (!ValuesArray) { + Ctx.emitError("Expected an array of {tensor_spec:<TensorSpec>, " + "logging_name:<name>} dictionaries"); + return std::nullopt; + } + std::vector<LoggedFeatureSpec> Ret; + for (const auto &Value : *ValuesArray) + if (const auto *Obj = Value.getAsObject()) + if (const auto *SpecPart = Obj->get("tensor_spec")) + if (auto TensorSpec = getTensorSpecFromJSON(Ctx, *SpecPart)) + if (auto LoggingName = Obj->getString("logging_name")) { + if (!TensorSpec->isElementType<int64_t>() && + !TensorSpec->isElementType<int32_t>() && + !TensorSpec->isElementType<float>()) { + Ctx.emitError( + "Only int64, int32, and float tensors are supported. " + "Found unsupported type for tensor named " + + TensorSpec->name()); + return std::nullopt; + } + Ret.push_back({*TensorSpec, LoggingName->str()}); + } + + if (ValuesArray->size() != Ret.size()) { + Ctx.emitError( + "Unable to parse output spec. It should be a json file containing an " + "array of dictionaries. Each dictionary must have a 'tensor_spec' key, " + "with a json object describing a TensorSpec; and a 'logging_name' key, " + "which is a string to use as name when logging this tensor in the " + "training log."); + return std::nullopt; + } + if (Ret.empty() || *Ret[0].LoggingName != ExpectedDecisionName) { + Ctx.emitError("The first output spec must describe the decision tensor, " + "and must have the logging_name " + + StringRef(ExpectedDecisionName)); + return std::nullopt; + } + return Ret; +} +} // namespace ModelUnderTrainingRunner::ModelUnderTrainingRunner( LLVMContext &Ctx, const std::string &ModelPath, const std::vector<TensorSpec> &InputSpecs, - const std::vector<LoggedFeatureSpec> &OutputSpecs) + const std::vector<TensorSpec> &OutputSpecs, + const std::vector<TensorSpec> &ExtraOutputsForLogging) : MLModelRunner(Ctx, MLModelRunner::Kind::Development, InputSpecs.size()), - OutputSpecs(OutputSpecs) { - Evaluator = std::make_unique<TFModelEvaluator>( - ModelPath, InputSpecs, [&](size_t I) { return OutputSpecs[I].Spec; }, - OutputSpecs.size()); + OutputSpecs(OutputSpecs), ExtraOutputsForLogging(ExtraOutputsForLogging) { + Evaluator = + std::make_unique<TFModelEvaluator>(ModelPath, InputSpecs, OutputSpecs); if (!Evaluator || !Evaluator->isValid()) { Ctx.emitError("Failed to create saved model evaluator"); Evaluator.reset(); @@ -40,7 +112,7 @@ ModelUnderTrainingRunner::ModelUnderTrainingRunner( void *ModelUnderTrainingRunner::evaluateUntyped() { LastEvaluationResult = Evaluator->evaluate(); - if (!LastEvaluationResult.hasValue()) { + if (!LastEvaluationResult.has_value()) { Ctx.emitError("Error evaluating model."); return nullptr; } @@ -53,26 +125,33 @@ ModelUnderTrainingRunner::createAndEnsureValid( const std::vector<TensorSpec> &InputSpecs, StringRef OutputSpecsPathOverride) { if (auto MaybeOutputSpecs = loadOutputSpecs(Ctx, DecisionName, ModelPath, - OutputSpecsPathOverride)) - return createAndEnsureValid(Ctx, ModelPath, DecisionName, InputSpecs, - *MaybeOutputSpecs); - Ctx.emitError("Could not load the policy model from the provided path"); - return nullptr; -} + OutputSpecsPathOverride)) { + std::unique_ptr<ModelUnderTrainingRunner> MUTR; + std::vector<TensorSpec> OutputSpecs; + std::vector<TensorSpec> ExtraOutputsForLogging; + append_range(OutputSpecs, + map_range(*MaybeOutputSpecs, [](const LoggedFeatureSpec &LFS) { + return LFS.Spec; + })); + append_range(ExtraOutputsForLogging, + map_range(drop_begin(*MaybeOutputSpecs), + [](const LoggedFeatureSpec &LFS) { + return TensorSpec(LFS.LoggingName + ? *LFS.LoggingName + : LFS.Spec.name(), + LFS.Spec); + })); -std::unique_ptr<ModelUnderTrainingRunner> -ModelUnderTrainingRunner::createAndEnsureValid( - LLVMContext &Ctx, const std::string &ModelPath, StringRef DecisionName, - const std::vector<TensorSpec> &InputSpecs, - const std::vector<LoggedFeatureSpec> &OutputSpecs) { - std::unique_ptr<ModelUnderTrainingRunner> MUTR; - MUTR.reset( - new ModelUnderTrainingRunner(Ctx, ModelPath, InputSpecs, OutputSpecs)); - if (MUTR && MUTR->isValid()) - return MUTR; + MUTR.reset(new ModelUnderTrainingRunner( + Ctx, ModelPath, InputSpecs, OutputSpecs, ExtraOutputsForLogging)); + if (MUTR && MUTR->isValid()) + return MUTR; - Ctx.emitError("Could not load or create model evaluator."); + Ctx.emitError("Could not load or create model evaluator."); + return nullptr; + } + Ctx.emitError("Could not load the policy model from the provided path"); return nullptr; } -#endif // defined(LLVM_HAVE_TF_API) +#endif // defined(LLVM_HAVE_TFLITE) |