From 145449b1e420787bb99721a429341fa6be3adfb6 Mon Sep 17 00:00:00 2001 From: Dimitry Andric Date: Sun, 3 Jul 2022 16:10:23 +0200 Subject: Vendor import of llvm-project main llvmorg-15-init-15358-g53dc0f107877. --- llvm/lib/Analysis/ModelUnderTrainingRunner.cpp | 29 ++++++++++++++++++-------- 1 file changed, 20 insertions(+), 9 deletions(-) (limited to 'llvm/lib/Analysis/ModelUnderTrainingRunner.cpp') diff --git a/llvm/lib/Analysis/ModelUnderTrainingRunner.cpp b/llvm/lib/Analysis/ModelUnderTrainingRunner.cpp index fab51d6a7aaf..dc149f326271 100644 --- a/llvm/lib/Analysis/ModelUnderTrainingRunner.cpp +++ b/llvm/lib/Analysis/ModelUnderTrainingRunner.cpp @@ -22,7 +22,7 @@ ModelUnderTrainingRunner::ModelUnderTrainingRunner( LLVMContext &Ctx, const std::string &ModelPath, const std::vector &InputSpecs, const std::vector &OutputSpecs) - : MLModelRunner(Ctx, MLModelRunner::Kind::Development), + : MLModelRunner(Ctx, MLModelRunner::Kind::Development, InputSpecs.size()), OutputSpecs(OutputSpecs) { Evaluator = std::make_unique( ModelPath, InputSpecs, [&](size_t I) { return OutputSpecs[I].Spec; }, @@ -32,6 +32,10 @@ ModelUnderTrainingRunner::ModelUnderTrainingRunner( Evaluator.reset(); return; } + + for (size_t I = 0, E = InputSpecs.size(); I < E; ++I) { + setUpBufferForTensor(I, InputSpecs[I], Evaluator->getUntypedInput(I)); + } } void *ModelUnderTrainingRunner::evaluateUntyped() { @@ -43,24 +47,31 @@ void *ModelUnderTrainingRunner::evaluateUntyped() { return LastEvaluationResult->getUntypedTensorValue(0); } -void *ModelUnderTrainingRunner::getTensorUntyped(size_t Index) { - return Evaluator->getUntypedInput(Index); -} - std::unique_ptr ModelUnderTrainingRunner::createAndEnsureValid( LLVMContext &Ctx, const std::string &ModelPath, StringRef DecisionName, const std::vector &InputSpecs, StringRef OutputSpecsPathOverride) { - std::unique_ptr MUTR; if (auto MaybeOutputSpecs = loadOutputSpecs(Ctx, DecisionName, ModelPath, OutputSpecsPathOverride)) - MUTR.reset(new ModelUnderTrainingRunner(Ctx, ModelPath, InputSpecs, - *MaybeOutputSpecs)); + return createAndEnsureValid(Ctx, ModelPath, DecisionName, InputSpecs, + *MaybeOutputSpecs); + Ctx.emitError("Could not load the policy model from the provided path"); + return nullptr; +} + +std::unique_ptr +ModelUnderTrainingRunner::createAndEnsureValid( + LLVMContext &Ctx, const std::string &ModelPath, StringRef DecisionName, + const std::vector &InputSpecs, + const std::vector &OutputSpecs) { + std::unique_ptr MUTR; + MUTR.reset( + new ModelUnderTrainingRunner(Ctx, ModelPath, InputSpecs, OutputSpecs)); if (MUTR && MUTR->isValid()) return MUTR; - Ctx.emitError("Could not load the policy model from the provided path"); + Ctx.emitError("Could not load or create model evaluator."); return nullptr; } -- cgit v1.2.3