diff options
author | Dimitry Andric <dim@FreeBSD.org> | 2022-07-03 14:10:23 +0000 |
---|---|---|
committer | Dimitry Andric <dim@FreeBSD.org> | 2022-07-03 14:10:23 +0000 |
commit | 145449b1e420787bb99721a429341fa6be3adfb6 (patch) | |
tree | 1d56ae694a6de602e348dd80165cf881a36600ed /llvm/lib/Analysis/ModelUnderTrainingRunner.cpp | |
parent | ecbca9f5fb7d7613d2b94982c4825eb0d33d6842 (diff) |
Diffstat (limited to 'llvm/lib/Analysis/ModelUnderTrainingRunner.cpp')
-rw-r--r-- | llvm/lib/Analysis/ModelUnderTrainingRunner.cpp | 29 |
1 files changed, 20 insertions, 9 deletions
diff --git a/llvm/lib/Analysis/ModelUnderTrainingRunner.cpp b/llvm/lib/Analysis/ModelUnderTrainingRunner.cpp
index fab51d6a7aaf..dc149f326271 100644
--- a/llvm/lib/Analysis/ModelUnderTrainingRunner.cpp
+++ b/llvm/lib/Analysis/ModelUnderTrainingRunner.cpp
@@ -22,7 +22,7 @@ ModelUnderTrainingRunner::ModelUnderTrainingRunner(
     LLVMContext &Ctx, const std::string &ModelPath,
     const std::vector<TensorSpec> &InputSpecs,
     const std::vector<LoggedFeatureSpec> &OutputSpecs)
-    : MLModelRunner(Ctx, MLModelRunner::Kind::Development),
+    : MLModelRunner(Ctx, MLModelRunner::Kind::Development, InputSpecs.size()),
       OutputSpecs(OutputSpecs) {
   Evaluator = std::make_unique<TFModelEvaluator>(
       ModelPath, InputSpecs, [&](size_t I) { return OutputSpecs[I].Spec; },
@@ -32,6 +32,10 @@ ModelUnderTrainingRunner::ModelUnderTrainingRunner(
     Evaluator.reset();
     return;
   }
+
+  for (size_t I = 0, E = InputSpecs.size(); I < E; ++I) {
+    setUpBufferForTensor(I, InputSpecs[I], Evaluator->getUntypedInput(I));
+  }
 }
 
 void *ModelUnderTrainingRunner::evaluateUntyped() {
@@ -43,24 +47,31 @@ void *ModelUnderTrainingRunner::evaluateUntyped() {
   return LastEvaluationResult->getUntypedTensorValue(0);
 }
 
-void *ModelUnderTrainingRunner::getTensorUntyped(size_t Index) {
-  return Evaluator->getUntypedInput(Index);
-}
-
 std::unique_ptr<ModelUnderTrainingRunner>
 ModelUnderTrainingRunner::createAndEnsureValid(
     LLVMContext &Ctx, const std::string &ModelPath, StringRef DecisionName,
     const std::vector<TensorSpec> &InputSpecs,
     StringRef OutputSpecsPathOverride) {
-  std::unique_ptr<ModelUnderTrainingRunner> MUTR;
   if (auto MaybeOutputSpecs = loadOutputSpecs(Ctx, DecisionName, ModelPath,
                                               OutputSpecsPathOverride))
-    MUTR.reset(new ModelUnderTrainingRunner(Ctx, ModelPath, InputSpecs,
-                                            *MaybeOutputSpecs));
+    return createAndEnsureValid(Ctx, ModelPath, DecisionName, InputSpecs,
+                                *MaybeOutputSpecs);
+  Ctx.emitError("Could not load the policy model from the provided path");
+  return nullptr;
+}
+
+std::unique_ptr<ModelUnderTrainingRunner>
+ModelUnderTrainingRunner::createAndEnsureValid(
+    LLVMContext &Ctx, const std::string &ModelPath, StringRef DecisionName,
+    const std::vector<TensorSpec> &InputSpecs,
+    const std::vector<LoggedFeatureSpec> &OutputSpecs) {
+  std::unique_ptr<ModelUnderTrainingRunner> MUTR;
+  MUTR.reset(
+      new ModelUnderTrainingRunner(Ctx, ModelPath, InputSpecs, OutputSpecs));
   if (MUTR && MUTR->isValid())
     return MUTR;
 
-  Ctx.emitError("Could not load the policy model from the provided path");
+  Ctx.emitError("Could not load or create model evaluator.");
   return nullptr;
 }
 