blob: 2e3ea5482dab0bc7dfc6f489c750d81600202f21 (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
|
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "onnx_model_cache.h"
namespace vespalib::eval {
// Guards all access to _cached, including the per-entry reference counts.
std::mutex OnnxModelCache::_lock{};
// Process-wide model cache, keyed by model file path; entries carry a
// reference count (num_refs) and are dropped when the last Token is released.
OnnxModelCache::Map OnnxModelCache::_cached{};
void
OnnxModelCache::release(Map::iterator entry)
{
    // Drop one reference to the given cache entry; the entry (and the
    // loaded model it owns) is destroyed when the last reference goes away.
    std::lock_guard<std::mutex> guard(_lock);
    size_t remaining = --(entry->second.num_refs);
    if (remaining == 0) {
        _cached.erase(entry);
    }
}
OnnxModelCache::Token::UP
OnnxModelCache::load(const vespalib::string &model_file)
{
    // Obtain a reference-counted handle to the model stored in the given
    // file, loading it first unless it is already cached. The global lock
    // is held for the duration, so a model is never loaded twice.
    std::lock_guard<std::mutex> guard(_lock);
    auto entry = _cached.find(model_file);
    if (entry == _cached.end()) {
        // Not cached yet; load it now (with optimization enabled) and
        // insert it under its file name.
        auto loaded = std::make_unique<Onnx>(model_file, Onnx::Optimize::ENABLE);
        auto result = _cached.emplace(model_file, std::move(loaded));
        assert(result.second); // key was absent, so insertion must succeed
        entry = result.first;
    }
    return std::make_unique<Token>(entry, ctor_tag());
}
size_t
OnnxModelCache::num_cached()
{
std::lock_guard<std::mutex> guard(_lock);
return _cached.size();
}
size_t
OnnxModelCache::count_refs()
{
    // Total number of outstanding Token references, summed over all
    // cached models.
    std::lock_guard<std::mutex> guard(_lock);
    size_t total = 0;
    for (auto itr = _cached.begin(); itr != _cached.end(); ++itr) {
        total += itr->second.num_refs;
    }
    return total;
}
}
|