about summary refs log tree commit diff
path: root/llvm/lib/Analysis/ModelUnderTrainingRunner.cpp
diff options
context:
space:
mode:
author Dimitry Andric <dim@FreeBSD.org> 2022-07-03 14:10:23 +0000
committer Dimitry Andric <dim@FreeBSD.org> 2022-07-03 14:10:23 +0000
commit 145449b1e420787bb99721a429341fa6be3adfb6 (patch)
tree 1d56ae694a6de602e348dd80165cf881a36600ed /llvm/lib/Analysis/ModelUnderTrainingRunner.cpp
parent ecbca9f5fb7d7613d2b94982c4825eb0d33d6842 (diff)
Diffstat (limited to 'llvm/lib/Analysis/ModelUnderTrainingRunner.cpp')
-rw-r--r-- llvm/lib/Analysis/ModelUnderTrainingRunner.cpp 29
1 files changed, 20 insertions, 9 deletions
diff --git a/llvm/lib/Analysis/ModelUnderTrainingRunner.cpp b/llvm/lib/Analysis/ModelUnderTrainingRunner.cpp
index fab51d6a7aaf..dc149f326271 100644
--- a/llvm/lib/Analysis/ModelUnderTrainingRunner.cpp
+++ b/llvm/lib/Analysis/ModelUnderTrainingRunner.cpp
@@ -22,7 +22,7 @@ ModelUnderTrainingRunner::ModelUnderTrainingRunner(
LLVMContext &Ctx, const std::string &ModelPath,
const std::vector<TensorSpec> &InputSpecs,
const std::vector<LoggedFeatureSpec> &OutputSpecs)
- : MLModelRunner(Ctx, MLModelRunner::Kind::Development),
+ : MLModelRunner(Ctx, MLModelRunner::Kind::Development, InputSpecs.size()),
OutputSpecs(OutputSpecs) {
Evaluator = std::make_unique<TFModelEvaluator>(
ModelPath, InputSpecs, [&](size_t I) { return OutputSpecs[I].Spec; },
@@ -32,6 +32,10 @@ ModelUnderTrainingRunner::ModelUnderTrainingRunner(
Evaluator.reset();
return;
}
+
+ for (size_t I = 0, E = InputSpecs.size(); I < E; ++I) {
+ setUpBufferForTensor(I, InputSpecs[I], Evaluator->getUntypedInput(I));
+ }
}
void *ModelUnderTrainingRunner::evaluateUntyped() {
@@ -43,24 +47,31 @@ void *ModelUnderTrainingRunner::evaluateUntyped() {
return LastEvaluationResult->getUntypedTensorValue(0);
}
-void *ModelUnderTrainingRunner::getTensorUntyped(size_t Index) {
- return Evaluator->getUntypedInput(Index);
-}
-
std::unique_ptr<ModelUnderTrainingRunner>
ModelUnderTrainingRunner::createAndEnsureValid(
LLVMContext &Ctx, const std::string &ModelPath, StringRef DecisionName,
const std::vector<TensorSpec> &InputSpecs,
StringRef OutputSpecsPathOverride) {
- std::unique_ptr<ModelUnderTrainingRunner> MUTR;
if (auto MaybeOutputSpecs = loadOutputSpecs(Ctx, DecisionName, ModelPath,
OutputSpecsPathOverride))
- MUTR.reset(new ModelUnderTrainingRunner(Ctx, ModelPath, InputSpecs,
- *MaybeOutputSpecs));
+ return createAndEnsureValid(Ctx, ModelPath, DecisionName, InputSpecs,
+ *MaybeOutputSpecs);
+ Ctx.emitError("Could not load the policy model from the provided path");
+ return nullptr;
+}
+
+std::unique_ptr<ModelUnderTrainingRunner>
+ModelUnderTrainingRunner::createAndEnsureValid(
+ LLVMContext &Ctx, const std::string &ModelPath, StringRef DecisionName,
+ const std::vector<TensorSpec> &InputSpecs,
+ const std::vector<LoggedFeatureSpec> &OutputSpecs) {
+ std::unique_ptr<ModelUnderTrainingRunner> MUTR;
+ MUTR.reset(
+ new ModelUnderTrainingRunner(Ctx, ModelPath, InputSpecs, OutputSpecs));
if (MUTR && MUTR->isValid())
return MUTR;
- Ctx.emitError("Could not load the policy model from the provided path");
+ Ctx.emitError("Could not load or create model evaluator.");
return nullptr;
}