diff options
author | Håvard Pettersen <havardpe@oath.com> | 2021-06-11 13:59:15 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2021-06-11 15:22:03 +0000 |
commit | 60642e69ea89ef4386763db5b3a54f212b9a557b (patch) | |
tree | e0ecff3caa48cc2c569b0b8efff38849741e7157 /eval | |
parent | f163dd2b355819e490fd8f2e0327a3a6950bb94a (diff) |
added onnx model cache
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp | 20 | ||||
-rw-r--r-- | eval/src/vespa/eval/onnx/CMakeLists.txt | 1 | ||||
-rw-r--r-- | eval/src/vespa/eval/onnx/onnx_model_cache.cpp | 59 | ||||
-rw-r--r-- | eval/src/vespa/eval/onnx/onnx_model_cache.h | 66 |
4 files changed, 146 insertions, 0 deletions
diff --git a/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp b/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
index 54f958f8111..6b45172ef80 100644
--- a/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
+++ b/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
@@ -3,6 +3,7 @@
 #include <vespa/eval/eval/tensor_spec.h>
 #include <vespa/eval/eval/int8float.h>
 #include <vespa/eval/onnx/onnx_wrapper.h>
+#include <vespa/eval/onnx/onnx_model_cache.h>
 #include <vespa/vespalib/util/bfloat16.h>
 #include <vespa/vespalib/util/stringfmt.h>
 #include <vespa/vespalib/gtest/gtest.h>
@@ -443,4 +444,23 @@ TEST(OnnxTest, default_allocator_type) {
     fprintf(stderr, "default allocator type: %d\n", int(res));
 }
 
+TEST(OnnxModelCacheTest, share_and_evict_onnx_models) {
+    {
+        auto simple1 = OnnxModelCache::load(simple_model);
+        auto simple2 = OnnxModelCache::load(simple_model);
+        auto dynamic1 = OnnxModelCache::load(dynamic_model);
+        auto dynamic2 = OnnxModelCache::load(dynamic_model);
+        auto dynamic3 = OnnxModelCache::load(dynamic_model);
+        EXPECT_EQ(simple1->get().inputs().size(), 3);
+        EXPECT_EQ(dynamic1->get().inputs().size(), 3);
+        EXPECT_EQ(&(simple1->get()), &(simple2->get()));
+        EXPECT_EQ(&(dynamic1->get()), &(dynamic2->get()));
+        EXPECT_EQ(&(dynamic2->get()), &(dynamic3->get()));
+        EXPECT_EQ(OnnxModelCache::num_cached(), 2);
+        EXPECT_EQ(OnnxModelCache::count_refs(), 5);
+    }
+    EXPECT_EQ(OnnxModelCache::num_cached(), 0);
+    EXPECT_EQ(OnnxModelCache::count_refs(), 0);
+}
+
 GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/eval/src/vespa/eval/onnx/CMakeLists.txt b/eval/src/vespa/eval/onnx/CMakeLists.txt
index 9b18557c036..40444936d02 100644
--- a/eval/src/vespa/eval/onnx/CMakeLists.txt
+++ b/eval/src/vespa/eval/onnx/CMakeLists.txt
@@ -2,5 +2,6 @@
 
 vespa_add_library(eval_onnx OBJECT
     SOURCES
+    onnx_model_cache.cpp
     onnx_wrapper.cpp
 )
diff --git a/eval/src/vespa/eval/onnx/onnx_model_cache.cpp b/eval/src/vespa/eval/onnx/onnx_model_cache.cpp
new file mode 100644
index 00000000000..01d5fdd9c84
--- /dev/null
+++ b/eval/src/vespa/eval/onnx/onnx_model_cache.cpp
@@ -0,0 +1,59 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "onnx_model_cache.h"
+#include <cassert>
+
+namespace vespalib::eval {
+
+std::mutex OnnxModelCache::_lock{};
+OnnxModelCache::Map OnnxModelCache::_cached{};
+
+// Drop one reference to 'entry'; the entry (and the model it owns) is
+// erased from the cache when its reference count reaches zero.
+void
+OnnxModelCache::release(Map::iterator entry)
+{
+    std::lock_guard<std::mutex> guard(_lock);
+    if (--(entry->second.num_refs) == 0) {
+        _cached.erase(entry);
+    }
+}
+
+// Look up 'model_file' in the cache, loading the model on a miss. The
+// lock is held for the whole call, so the Token constructor can bump
+// the reference count without taking the lock itself.
+OnnxModelCache::Token::UP
+OnnxModelCache::load(const vespalib::string &model_file)
+{
+    std::lock_guard<std::mutex> guard(_lock);
+    auto pos = _cached.find(model_file);
+    if (pos == _cached.end()) {
+        auto model = std::make_unique<Onnx>(model_file, Onnx::Optimize::ENABLE);
+        auto res = _cached.emplace(model_file, std::move(model));
+        assert(res.second);
+        pos = res.first;
+    }
+    return std::make_unique<Token>(pos, ctor_tag());
+}
+
+// Number of distinct models currently held by the cache.
+size_t
+OnnxModelCache::num_cached()
+{
+    std::lock_guard<std::mutex> guard(_lock);
+    return _cached.size();
+}
+
+// Total number of references summed over all cached models.
+size_t
+OnnxModelCache::count_refs()
+{
+    std::lock_guard<std::mutex> guard(_lock);
+    size_t refs = 0;
+    for (const auto &entry: _cached) {
+        refs += entry.second.num_refs;
+    }
+    return refs;
+}
+
+}
diff --git a/eval/src/vespa/eval/onnx/onnx_model_cache.h b/eval/src/vespa/eval/onnx/onnx_model_cache.h
new file mode 100644
index 00000000000..35d5fefa061
--- /dev/null
+++ b/eval/src/vespa/eval/onnx/onnx_model_cache.h
@@ -0,0 +1,66 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include "onnx_wrapper.h"
+#include <vespa/vespalib/stllike/string.h>
+#include <memory>
+#include <mutex>
+#include <map>
+
+namespace vespalib::eval {
+
+/**
+ * Cache used to share loaded onnx models between users. The cache
+ * itself will not keep anything alive, but will let you find loaded
+ * models that are currently in use by others.
+ **/
+class OnnxModelCache
+{
+private:
+    // Private tag taken by the Token constructor; intended to keep Token creation inside the cache.
+    struct ctor_tag {};
+    using Key = vespalib::string;
+    // A loaded model together with the number of Tokens currently referring to it.
+    struct Value {
+        size_t num_refs;
+        std::unique_ptr<Onnx> model;
+        explicit Value(std::unique_ptr<Onnx> model_in) : num_refs(0), model(std::move(model_in)) {}
+        const Onnx &get() const { return *model; }
+    };
+    using Map = std::map<Key,Value>;
+    static std::mutex _lock; // taken by load, release, num_cached and count_refs
+    static Map _cached;
+
+    static void release(Map::iterator entry);
+
+public:
+    /**
+     * Handle referring to a cached model. Destroying the last Token
+     * for a model removes that model from the cache.
+     **/
+    class Token
+    {
+    private:
+        OnnxModelCache::Map::iterator _entry;
+    public:
+        Token(Token &&) = delete;
+        Token(const Token &) = delete;
+        Token &operator=(Token &&) = delete;
+        Token &operator=(const Token &) = delete;
+        using UP = std::unique_ptr<Token>;
+        // Does not lock; load() constructs Tokens while holding _lock.
+        explicit Token(OnnxModelCache::Map::iterator entry, ctor_tag) : _entry(entry) {
+            ++_entry->second.num_refs;
+        }
+        const Onnx &get() const { return _entry->second.get(); }
+        ~Token() { OnnxModelCache::release(_entry); }
+    };
+
+    // Load the model in 'model_file', or share the copy already in the cache.
+    static Token::UP load(const vespalib::string &model_file);
+    static size_t num_cached(); // number of models currently cached
+    static size_t count_refs(); // sum of num_refs over all cached models
+};
+
+}