 eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp | 20 +
 eval/src/vespa/eval/onnx/CMakeLists.txt                  |  1 +
 eval/src/vespa/eval/onnx/onnx_model_cache.cpp            | 51 +
 eval/src/vespa/eval/onnx/onnx_model_cache.h              | 58 +
 searchlib/src/vespa/searchlib/features/onnx_feature.cpp  | 15 +-
 searchlib/src/vespa/searchlib/features/onnx_feature.h    |  8 +-
 6 files changed, 146 insertions(+), 7 deletions(-)
diff --git a/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp b/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
index 54f958f8111..6b45172ef80 100644
--- a/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
+++ b/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
@@ -3,6 +3,7 @@
#include <vespa/eval/eval/tensor_spec.h>
#include <vespa/eval/eval/int8float.h>
#include <vespa/eval/onnx/onnx_wrapper.h>
+#include <vespa/eval/onnx/onnx_model_cache.h>
#include <vespa/vespalib/util/bfloat16.h>
#include <vespa/vespalib/util/stringfmt.h>
#include <vespa/vespalib/gtest/gtest.h>
@@ -443,4 +444,23 @@ TEST(OnnxTest, default_allocator_type) {
fprintf(stderr, "default allocator type: %d\n", int(res));
}
+TEST(OnnxModelCacheTest, share_and_evict_onnx_models) {
+ {
+ auto simple1 = OnnxModelCache::load(simple_model);
+ auto simple2 = OnnxModelCache::load(simple_model);
+ auto dynamic1 = OnnxModelCache::load(dynamic_model);
+ auto dynamic2 = OnnxModelCache::load(dynamic_model);
+ auto dynamic3 = OnnxModelCache::load(dynamic_model);
+ EXPECT_EQ(simple1->get().inputs().size(), 3);
+ EXPECT_EQ(dynamic1->get().inputs().size(), 3);
+ EXPECT_EQ(&(simple1->get()), &(simple2->get()));
+ EXPECT_EQ(&(dynamic1->get()), &(dynamic2->get()));
+ EXPECT_EQ(&(dynamic2->get()), &(dynamic3->get()));
+ EXPECT_EQ(OnnxModelCache::num_cached(), 2);
+ EXPECT_EQ(OnnxModelCache::count_refs(), 5);
+ }
+ EXPECT_EQ(OnnxModelCache::num_cached(), 0);
+ EXPECT_EQ(OnnxModelCache::count_refs(), 0);
+}
+
GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/eval/src/vespa/eval/onnx/CMakeLists.txt b/eval/src/vespa/eval/onnx/CMakeLists.txt
index 9b18557c036..40444936d02 100644
--- a/eval/src/vespa/eval/onnx/CMakeLists.txt
+++ b/eval/src/vespa/eval/onnx/CMakeLists.txt
@@ -2,5 +2,6 @@
vespa_add_library(eval_onnx OBJECT
SOURCES
+ onnx_model_cache.cpp
onnx_wrapper.cpp
)
diff --git a/eval/src/vespa/eval/onnx/onnx_model_cache.cpp b/eval/src/vespa/eval/onnx/onnx_model_cache.cpp
new file mode 100644
index 00000000000..01d5fdd9c84
--- /dev/null
+++ b/eval/src/vespa/eval/onnx/onnx_model_cache.cpp
@@ -0,0 +1,51 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "onnx_model_cache.h"
+
+namespace vespalib::eval {
+
+std::mutex OnnxModelCache::_lock{};
+OnnxModelCache::Map OnnxModelCache::_cached{};
+
+void
+OnnxModelCache::release(Map::iterator entry)
+{
+ std::lock_guard<std::mutex> guard(_lock);
+ if (--(entry->second.num_refs) == 0) {
+ _cached.erase(entry);
+ }
+}
+
+OnnxModelCache::Token::UP
+OnnxModelCache::load(const vespalib::string &model_file)
+{
+ std::lock_guard<std::mutex> guard(_lock);
+ auto pos = _cached.find(model_file);
+ if (pos == _cached.end()) {
+ auto model = std::make_unique<Onnx>(model_file, Onnx::Optimize::ENABLE);
+ auto res = _cached.emplace(model_file, std::move(model));
+ assert(res.second);
+ pos = res.first;
+ }
+ return std::make_unique<Token>(pos, ctor_tag());
+}
+
+size_t
+OnnxModelCache::num_cached()
+{
+ std::lock_guard<std::mutex> guard(_lock);
+ return _cached.size();
+}
+
+size_t
+OnnxModelCache::count_refs()
+{
+ std::lock_guard<std::mutex> guard(_lock);
+ size_t refs = 0;
+ for (const auto &entry: _cached) {
+ refs += entry.second.num_refs;
+ }
+ return refs;
+}
+
+}
diff --git a/eval/src/vespa/eval/onnx/onnx_model_cache.h b/eval/src/vespa/eval/onnx/onnx_model_cache.h
new file mode 100644
index 00000000000..35d5fefa061
--- /dev/null
+++ b/eval/src/vespa/eval/onnx/onnx_model_cache.h
@@ -0,0 +1,58 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include "onnx_wrapper.h"
+#include <vespa/vespalib/stllike/string.h>
+#include <memory>
+#include <mutex>
+#include <map>
+
+namespace vespalib::eval {
+
+/**
+ * Cache used to share loaded onnx models between users. The cache
+ * itself will not keep anything alive, but will let you find loaded
+ * models that are currently in use by others.
+ **/
+class OnnxModelCache
+{
+private:
+ struct ctor_tag {};
+ using Key = vespalib::string;
+ struct Value {
+ size_t num_refs;
+ std::unique_ptr<Onnx> model;
+ Value(std::unique_ptr<Onnx> model_in) : num_refs(0), model(std::move(model_in)) {}
+ const Onnx &get() { return *model; }
+ };
+ using Map = std::map<Key,Value>;
+ static std::mutex _lock;
+ static Map _cached;
+
+ static void release(Map::iterator entry);
+
+public:
+ class Token
+ {
+ private:
+ OnnxModelCache::Map::iterator _entry;
+ public:
+ Token(Token &&) = delete;
+ Token(const Token &) = delete;
+ Token &operator=(Token &&) = delete;
+ Token &operator=(const Token &) = delete;
+ using UP = std::unique_ptr<Token>;
+ explicit Token(OnnxModelCache::Map::iterator entry, ctor_tag) : _entry(entry) {
+ ++_entry->second.num_refs;
+ }
+ const Onnx &get() const { return _entry->second.get(); }
+ ~Token() { OnnxModelCache::release(_entry); }
+ };
+
+ static Token::UP load(const vespalib::string &model_file);
+ static size_t num_cached();
+ static size_t count_refs();
+};
+
+}
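
Note: a minimal usage sketch of the new OnnxModelCache API, mirroring the test added above; the model path "my_model.onnx" is a hypothetical placeholder, not a file from this change. Tokens returned by load() share one loaded Onnx instance per file path, and the cache entry is evicted once the last token for that path is destroyed.

    #include <vespa/eval/onnx/onnx_model_cache.h>
    #include <cassert>

    using vespalib::eval::OnnxModelCache;

    void use_shared_model() {
        // "my_model.onnx" is a placeholder path; the first load() parses the
        // model (with Onnx::Optimize::ENABLE), later calls reuse that entry.
        auto token_a = OnnxModelCache::load("my_model.onnx");
        auto token_b = OnnxModelCache::load("my_model.onnx");
        // Both tokens expose the same underlying Onnx instance.
        assert(&token_a->get() == &token_b->get());
        assert(OnnxModelCache::num_cached() == 1);
        // When token_a and token_b go out of scope, the entry is erased.
    }
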
diff --git a/searchlib/src/vespa/searchlib/features/onnx_feature.cpp b/searchlib/src/vespa/searchlib/features/onnx_feature.cpp
index 87e5ef2a5c2..e9fecb3578e 100644
--- a/searchlib/src/vespa/searchlib/features/onnx_feature.cpp
+++ b/searchlib/src/vespa/searchlib/features/onnx_feature.cpp
@@ -69,6 +69,8 @@ public:
OnnxBlueprint::OnnxBlueprint()
: Blueprint("onnxModel"),
+ _cache_token(),
+ _debug_model(),
_model(nullptr),
_wire_info()
{
@@ -80,15 +82,18 @@ bool
OnnxBlueprint::setup(const IIndexEnvironment &env,
const ParameterList &params)
{
- auto optimize = (env.getFeatureMotivation() == env.FeatureMotivation::VERIFY_SETUP)
- ? Onnx::Optimize::DISABLE
- : Onnx::Optimize::ENABLE;
auto model_cfg = env.getOnnxModel(params[0].getValue());
if (!model_cfg) {
return fail("no model with name '%s' found", params[0].getValue().c_str());
}
try {
- _model = std::make_unique<Onnx>(model_cfg->file_path(), optimize);
+ if (env.getFeatureMotivation() == env.FeatureMotivation::VERIFY_SETUP) {
+ _debug_model = std::make_unique<Onnx>(model_cfg->file_path(), Optimize::DISABLE);
+ _model = _debug_model.get();
+ } else {
+ _cache_token = OnnxModelCache::load(model_cfg->file_path());
+ _model = &(_cache_token->get());
+ }
} catch (std::exception &ex) {
return fail("model setup failed: %s", ex.what());
}
@@ -132,7 +137,7 @@ OnnxBlueprint::setup(const IIndexEnvironment &env,
FeatureExecutor &
OnnxBlueprint::createExecutor(const IQueryEnvironment &, Stash &stash) const
{
- assert(_model);
+ assert(_model != nullptr);
return stash.create<OnnxFeatureExecutor>(*_model, _wire_info);
}
diff --git a/searchlib/src/vespa/searchlib/features/onnx_feature.h b/searchlib/src/vespa/searchlib/features/onnx_feature.h
index 6a63e7276c2..ed0fbc502f0 100644
--- a/searchlib/src/vespa/searchlib/features/onnx_feature.h
+++ b/searchlib/src/vespa/searchlib/features/onnx_feature.h
@@ -3,7 +3,7 @@
#pragma once
#include <vespa/searchlib/fef/blueprint.h>
-#include <vespa/eval/onnx/onnx_wrapper.h>
+#include <vespa/eval/onnx/onnx_model_cache.h>
namespace search::features {
@@ -13,7 +13,11 @@ namespace search::features {
class OnnxBlueprint : public fef::Blueprint {
private:
using Onnx = vespalib::eval::Onnx;
- std::unique_ptr<Onnx> _model;
+ using Optimize = vespalib::eval::Onnx::Optimize;
+ using OnnxModelCache = vespalib::eval::OnnxModelCache;
+ OnnxModelCache::Token::UP _cache_token;
+ std::unique_ptr<Onnx> _debug_model;
+ const Onnx *_model;
Onnx::WireInfo _wire_info;
public:
OnnxBlueprint();