summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2021-06-11 13:59:15 +0000
committerHåvard Pettersen <havardpe@oath.com>2021-06-11 15:22:03 +0000
commit60642e69ea89ef4386763db5b3a54f212b9a557b (patch)
treee0ecff3caa48cc2c569b0b8efff38849741e7157 /eval
parentf163dd2b355819e490fd8f2e0327a3a6950bb94a (diff)
added onnx model cache
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp20
-rw-r--r--eval/src/vespa/eval/onnx/CMakeLists.txt1
-rw-r--r--eval/src/vespa/eval/onnx/onnx_model_cache.cpp51
-rw-r--r--eval/src/vespa/eval/onnx/onnx_model_cache.h58
4 files changed, 130 insertions, 0 deletions
diff --git a/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp b/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
index 54f958f8111..6b45172ef80 100644
--- a/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
+++ b/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
@@ -3,6 +3,7 @@
#include <vespa/eval/eval/tensor_spec.h>
#include <vespa/eval/eval/int8float.h>
#include <vespa/eval/onnx/onnx_wrapper.h>
+#include <vespa/eval/onnx/onnx_model_cache.h>
#include <vespa/vespalib/util/bfloat16.h>
#include <vespa/vespalib/util/stringfmt.h>
#include <vespa/vespalib/gtest/gtest.h>
@@ -443,4 +444,23 @@ TEST(OnnxTest, default_allocator_type) {
fprintf(stderr, "default allocator type: %d\n", int(res));
}
+TEST(OnnxModelCacheTest, share_and_evict_onnx_models) {
+ {
+ auto simple1 = OnnxModelCache::load(simple_model); // first load creates the cache entry
+ auto simple2 = OnnxModelCache::load(simple_model); // same file -> shares the entry above
+ auto dynamic1 = OnnxModelCache::load(dynamic_model); // first load of the second model
+ auto dynamic2 = OnnxModelCache::load(dynamic_model);
+ auto dynamic3 = OnnxModelCache::load(dynamic_model);
+ EXPECT_EQ(simple1->get().inputs().size(), 3);
+ EXPECT_EQ(dynamic1->get().inputs().size(), 3);
+ EXPECT_EQ(&(simple1->get()), &(simple2->get())); // shared: both tokens expose the same Onnx instance
+ EXPECT_EQ(&(dynamic1->get()), &(dynamic2->get()));
+ EXPECT_EQ(&(dynamic2->get()), &(dynamic3->get()));
+ EXPECT_EQ(OnnxModelCache::num_cached(), 2); // two distinct model files cached
+ EXPECT_EQ(OnnxModelCache::count_refs(), 5); // five live tokens in total (2 + 3)
+ }
+ EXPECT_EQ(OnnxModelCache::num_cached(), 0); // all tokens destroyed -> entries evicted
+ EXPECT_EQ(OnnxModelCache::count_refs(), 0);
+}
+
GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/eval/src/vespa/eval/onnx/CMakeLists.txt b/eval/src/vespa/eval/onnx/CMakeLists.txt
index 9b18557c036..40444936d02 100644
--- a/eval/src/vespa/eval/onnx/CMakeLists.txt
+++ b/eval/src/vespa/eval/onnx/CMakeLists.txt
@@ -2,5 +2,6 @@
vespa_add_library(eval_onnx OBJECT
SOURCES
+ onnx_model_cache.cpp
onnx_wrapper.cpp
)
diff --git a/eval/src/vespa/eval/onnx/onnx_model_cache.cpp b/eval/src/vespa/eval/onnx/onnx_model_cache.cpp
new file mode 100644
index 00000000000..01d5fdd9c84
--- /dev/null
+++ b/eval/src/vespa/eval/onnx/onnx_model_cache.cpp
@@ -0,0 +1,51 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "onnx_model_cache.h"
+
+namespace vespalib::eval {
+
+std::mutex OnnxModelCache::_lock{}; // guards _cached and every num_refs update
+OnnxModelCache::Map OnnxModelCache::_cached{}; // process-wide cache, keyed by model file name
+
+void
+OnnxModelCache::release(Map::iterator entry)
+{
+ std::lock_guard<std::mutex> guard(_lock);
+ if (--(entry->second.num_refs) == 0) { // last Token released -> evict the model
+ _cached.erase(entry); // std::map::erase invalidates only the erased iterator
+ }
+}
+
+OnnxModelCache::Token::UP
+OnnxModelCache::load(const vespalib::string &model_file)
+{
+ std::lock_guard<std::mutex> guard(_lock); // NOTE(review): the model is loaded while holding the lock; a slow load blocks other lookups
+ auto pos = _cached.find(model_file);
+ if (pos == _cached.end()) { // not cached -> load and insert with num_refs == 0
+ auto model = std::make_unique<Onnx>(model_file, Onnx::Optimize::ENABLE);
+ auto res = _cached.emplace(model_file, std::move(model));
+ assert(res.second);
+ pos = res.first;
+ }
+ return std::make_unique<Token>(pos, ctor_tag()); // Token ctor bumps num_refs (lock still held)
+}
+
+size_t
+OnnxModelCache::num_cached()
+{
+ std::lock_guard<std::mutex> guard(_lock);
+ return _cached.size(); // number of distinct model files currently cached
+}
+
+size_t
+OnnxModelCache::count_refs()
+{
+ std::lock_guard<std::mutex> guard(_lock);
+ size_t refs = 0;
+ for (const auto &entry: _cached) { // total live Tokens across all entries
+ refs += entry.second.num_refs;
+ }
+ return refs;
+}
+
+}
diff --git a/eval/src/vespa/eval/onnx/onnx_model_cache.h b/eval/src/vespa/eval/onnx/onnx_model_cache.h
new file mode 100644
index 00000000000..35d5fefa061
--- /dev/null
+++ b/eval/src/vespa/eval/onnx/onnx_model_cache.h
@@ -0,0 +1,58 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include "onnx_wrapper.h"
+#include <vespa/vespalib/stllike/string.h>
+#include <memory>
+#include <mutex>
+#include <map>
+
+namespace vespalib::eval {
+
+/**
+ * Cache used to share loaded onnx models between users. The cache
+ * itself will not keep anything alive, but will let you find loaded
+ * models that are currently in use by others.
+ **/
+class OnnxModelCache
+{
+private:
+ struct ctor_tag {}; // restricts Token construction to OnnxModelCache::load
+ using Key = vespalib::string; // model file name
+ struct Value {
+ size_t num_refs; // number of live Tokens referring to this entry
+ std::unique_ptr<Onnx> model;
+ Value(std::unique_ptr<Onnx> model_in) : num_refs(0), model(std::move(model_in)) {}
+ const Onnx &get() { return *model; }
+ };
+ using Map = std::map<Key,Value>; // std::map: iterators stay valid on insert/erase, so a Token may hold one
+ static std::mutex _lock; // guards _cached and all num_refs updates
+ static Map _cached;
+
+ static void release(Map::iterator entry); // called by ~Token; evicts the entry when num_refs drops to 0
+
+public:
+ class Token
+ {
+ private:
+ OnnxModelCache::Map::iterator _entry;
+ public:
+ Token(Token &&) = delete; // non-copyable/non-movable: exactly one release per Token
+ Token(const Token &) = delete;
+ Token &operator=(Token &&) = delete;
+ Token &operator=(const Token &) = delete;
+ using UP = std::unique_ptr<Token>;
+ explicit Token(OnnxModelCache::Map::iterator entry, ctor_tag) : _entry(entry) {
+ ++_entry->second.num_refs; // caller (load) holds _lock here
+ }
+ const Onnx &get() const { return _entry->second.get(); }
+ ~Token() { OnnxModelCache::release(_entry); }
+ };
+
+ static Token::UP load(const vespalib::string &model_file); // load (or share an already loaded) model
+ static size_t num_cached(); // number of distinct models currently cached
+ static size_t count_refs(); // total number of live Tokens
+};
+
+}