summaryrefslogtreecommitdiffstats
path: root/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2021-06-11 13:59:15 +0000
committerHåvard Pettersen <havardpe@oath.com>2021-06-11 15:22:03 +0000
commit60642e69ea89ef4386763db5b3a54f212b9a557b (patch)
treee0ecff3caa48cc2c569b0b8efff38849741e7157 /eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
parentf163dd2b355819e490fd8f2e0327a3a6950bb94a (diff)
added onnx model cache
Diffstat (limited to 'eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp')
-rw-r--r--eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp20
1 files changed, 20 insertions, 0 deletions
diff --git a/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp b/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
index 54f958f8111..6b45172ef80 100644
--- a/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
+++ b/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
@@ -3,6 +3,7 @@
#include <vespa/eval/eval/tensor_spec.h>
#include <vespa/eval/eval/int8float.h>
#include <vespa/eval/onnx/onnx_wrapper.h>
+#include <vespa/eval/onnx/onnx_model_cache.h>
#include <vespa/vespalib/util/bfloat16.h>
#include <vespa/vespalib/util/stringfmt.h>
#include <vespa/vespalib/gtest/gtest.h>
@@ -443,4 +444,23 @@ TEST(OnnxTest, default_allocator_type) {
fprintf(stderr, "default allocator type: %d\n", int(res));
}
+TEST(OnnxModelCacheTest, share_and_evict_onnx_models) {
+ {
+ auto simple1 = OnnxModelCache::load(simple_model);
+ auto simple2 = OnnxModelCache::load(simple_model);
+ auto dynamic1 = OnnxModelCache::load(dynamic_model);
+ auto dynamic2 = OnnxModelCache::load(dynamic_model);
+ auto dynamic3 = OnnxModelCache::load(dynamic_model);
+ EXPECT_EQ(simple1->get().inputs().size(), 3);
+ EXPECT_EQ(dynamic1->get().inputs().size(), 3);
+ EXPECT_EQ(&(simple1->get()), &(simple2->get()));
+ EXPECT_EQ(&(dynamic1->get()), &(dynamic2->get()));
+ EXPECT_EQ(&(dynamic2->get()), &(dynamic3->get()));
+ EXPECT_EQ(OnnxModelCache::num_cached(), 2);
+ EXPECT_EQ(OnnxModelCache::count_refs(), 5);
+ }
+ EXPECT_EQ(OnnxModelCache::num_cached(), 0);
+ EXPECT_EQ(OnnxModelCache::count_refs(), 0);
+}
+
GTEST_MAIN_RUN_ALL_TESTS()