diff options
Diffstat (limited to 'searchlib')
-rw-r--r-- | searchlib/src/vespa/searchlib/features/onnx_feature.cpp | 15 | ||||
-rw-r--r-- | searchlib/src/vespa/searchlib/features/onnx_feature.h | 8 |
2 files changed, 16 insertions, 7 deletions
diff --git a/searchlib/src/vespa/searchlib/features/onnx_feature.cpp b/searchlib/src/vespa/searchlib/features/onnx_feature.cpp index 87e5ef2a5c2..e9fecb3578e 100644 --- a/searchlib/src/vespa/searchlib/features/onnx_feature.cpp +++ b/searchlib/src/vespa/searchlib/features/onnx_feature.cpp @@ -69,6 +69,8 @@ public: OnnxBlueprint::OnnxBlueprint() : Blueprint("onnxModel"), + _cache_token(), + _debug_model(), _model(nullptr), _wire_info() { @@ -80,15 +82,18 @@ bool OnnxBlueprint::setup(const IIndexEnvironment &env, const ParameterList &params) { - auto optimize = (env.getFeatureMotivation() == env.FeatureMotivation::VERIFY_SETUP) - ? Onnx::Optimize::DISABLE - : Onnx::Optimize::ENABLE; auto model_cfg = env.getOnnxModel(params[0].getValue()); if (!model_cfg) { return fail("no model with name '%s' found", params[0].getValue().c_str()); } try { - _model = std::make_unique<Onnx>(model_cfg->file_path(), optimize); + if (env.getFeatureMotivation() == env.FeatureMotivation::VERIFY_SETUP) { + _debug_model = std::make_unique<Onnx>(model_cfg->file_path(), Optimize::DISABLE); + _model = _debug_model.get(); + } else { + _cache_token = OnnxModelCache::load(model_cfg->file_path()); + _model = &(_cache_token->get()); + } } catch (std::exception &ex) { return fail("model setup failed: %s", ex.what()); } @@ -132,7 +137,7 @@ OnnxBlueprint::setup(const IIndexEnvironment &env, FeatureExecutor & OnnxBlueprint::createExecutor(const IQueryEnvironment &, Stash &stash) const { - assert(_model); + assert(_model != nullptr); return stash.create<OnnxFeatureExecutor>(*_model, _wire_info); } diff --git a/searchlib/src/vespa/searchlib/features/onnx_feature.h b/searchlib/src/vespa/searchlib/features/onnx_feature.h index 6a63e7276c2..ed0fbc502f0 100644 --- a/searchlib/src/vespa/searchlib/features/onnx_feature.h +++ b/searchlib/src/vespa/searchlib/features/onnx_feature.h @@ -3,7 +3,7 @@ #pragma once #include <vespa/searchlib/fef/blueprint.h> -#include <vespa/eval/onnx/onnx_wrapper.h> +#include 
<vespa/eval/onnx/onnx_model_cache.h> namespace search::features { @@ -13,7 +13,11 @@ namespace search::features { class OnnxBlueprint : public fef::Blueprint { private: using Onnx = vespalib::eval::Onnx; - std::unique_ptr<Onnx> _model; + using Optimize = vespalib::eval::Onnx::Optimize; + using OnnxModelCache = vespalib::eval::OnnxModelCache; + OnnxModelCache::Token::UP _cache_token; + std::unique_ptr<Onnx> _debug_model; + const Onnx *_model; Onnx::WireInfo _wire_info; public: OnnxBlueprint(); |