summary | refs | log | tree | commit | diff | stats
path: root/searchlib
diff options
context:
space:
mode:
author Håvard Pettersen <havardpe@oath.com> 2021-06-11 13:59:15 +0000
committer Håvard Pettersen <havardpe@oath.com> 2021-06-11 15:22:03 +0000
commit 60642e69ea89ef4386763db5b3a54f212b9a557b (patch)
tree e0ecff3caa48cc2c569b0b8efff38849741e7157 /searchlib
parent f163dd2b355819e490fd8f2e0327a3a6950bb94a (diff)
added onnx model cache
Diffstat (limited to 'searchlib')
-rw-r--r-- searchlib/src/vespa/searchlib/features/onnx_feature.cpp 15
-rw-r--r-- searchlib/src/vespa/searchlib/features/onnx_feature.h 8
2 files changed, 16 insertions, 7 deletions
diff --git a/searchlib/src/vespa/searchlib/features/onnx_feature.cpp b/searchlib/src/vespa/searchlib/features/onnx_feature.cpp
index 87e5ef2a5c2..e9fecb3578e 100644
--- a/searchlib/src/vespa/searchlib/features/onnx_feature.cpp
+++ b/searchlib/src/vespa/searchlib/features/onnx_feature.cpp
@@ -69,6 +69,8 @@ public:
OnnxBlueprint::OnnxBlueprint()
: Blueprint("onnxModel"),
+ _cache_token(),
+ _debug_model(),
_model(nullptr),
_wire_info()
{
@@ -80,15 +82,18 @@ bool
OnnxBlueprint::setup(const IIndexEnvironment &env,
const ParameterList &params)
{
- auto optimize = (env.getFeatureMotivation() == env.FeatureMotivation::VERIFY_SETUP)
- ? Onnx::Optimize::DISABLE
- : Onnx::Optimize::ENABLE;
auto model_cfg = env.getOnnxModel(params[0].getValue());
if (!model_cfg) {
return fail("no model with name '%s' found", params[0].getValue().c_str());
}
try {
- _model = std::make_unique<Onnx>(model_cfg->file_path(), optimize);
+ if (env.getFeatureMotivation() == env.FeatureMotivation::VERIFY_SETUP) {
+ _debug_model = std::make_unique<Onnx>(model_cfg->file_path(), Optimize::DISABLE);
+ _model = _debug_model.get();
+ } else {
+ _cache_token = OnnxModelCache::load(model_cfg->file_path());
+ _model = &(_cache_token->get());
+ }
} catch (std::exception &ex) {
return fail("model setup failed: %s", ex.what());
}
@@ -132,7 +137,7 @@ OnnxBlueprint::setup(const IIndexEnvironment &env,
FeatureExecutor &
OnnxBlueprint::createExecutor(const IQueryEnvironment &, Stash &stash) const
{
- assert(_model);
+ assert(_model != nullptr);
return stash.create<OnnxFeatureExecutor>(*_model, _wire_info);
}
diff --git a/searchlib/src/vespa/searchlib/features/onnx_feature.h b/searchlib/src/vespa/searchlib/features/onnx_feature.h
index 6a63e7276c2..ed0fbc502f0 100644
--- a/searchlib/src/vespa/searchlib/features/onnx_feature.h
+++ b/searchlib/src/vespa/searchlib/features/onnx_feature.h
@@ -3,7 +3,7 @@
#pragma once
#include <vespa/searchlib/fef/blueprint.h>
-#include <vespa/eval/onnx/onnx_wrapper.h>
+#include <vespa/eval/onnx/onnx_model_cache.h>
namespace search::features {
@@ -13,7 +13,11 @@ namespace search::features {
class OnnxBlueprint : public fef::Blueprint {
private:
using Onnx = vespalib::eval::Onnx;
- std::unique_ptr<Onnx> _model;
+ using Optimize = vespalib::eval::Onnx::Optimize;
+ using OnnxModelCache = vespalib::eval::OnnxModelCache;
+ OnnxModelCache::Token::UP _cache_token;
+ std::unique_ptr<Onnx> _debug_model;
+ const Onnx *_model;
Onnx::WireInfo _wire_info;
public:
OnnxBlueprint();