summaryrefslogtreecommitdiffstats
path: root/searchlib/src/tests/tensor/direct_tensor_store/direct_tensor_store_test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'searchlib/src/tests/tensor/direct_tensor_store/direct_tensor_store_test.cpp')
-rw-r--r--searchlib/src/tests/tensor/direct_tensor_store/direct_tensor_store_test.cpp89
1 files changed, 89 insertions, 0 deletions
diff --git a/searchlib/src/tests/tensor/direct_tensor_store/direct_tensor_store_test.cpp b/searchlib/src/tests/tensor/direct_tensor_store/direct_tensor_store_test.cpp
new file mode 100644
index 00000000000..1003e461676
--- /dev/null
+++ b/searchlib/src/tests/tensor/direct_tensor_store/direct_tensor_store_test.cpp
@@ -0,0 +1,89 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/searchlib/tensor/direct_tensor_store.h>
+#include <vespa/vespalib/gtest/gtest.h>
+#include <vespa/eval/tensor/default_tensor_engine.h>
+#include <vespa/eval/tensor/tensor.h>
+
+using namespace search::tensor;
+
+using vespalib::datastore::EntryRef;
+using vespalib::eval::TensorSpec;
+using vespalib::tensor::DefaultTensorEngine;
+using vespalib::tensor::Tensor;
+
+vespalib::string tensor_spec("tensor(x{})");
+
+Tensor::UP
+make_tensor(const TensorSpec& spec)
+{
+ auto value = DefaultTensorEngine::ref().from_spec(spec);
+ auto* tensor = dynamic_cast<Tensor*>(value.get());
+ assert(tensor != nullptr);
+ value.release();
+ return Tensor::UP(tensor);
+}
+
+Tensor::UP
+make_tensor(double value)
+{
+ return make_tensor(TensorSpec(tensor_spec).add({{"x", "a"}}, value));
+}
+
+class DirectTensorStoreTest : public ::testing::Test {
+public:
+ DirectTensorStore store;
+
+ DirectTensorStoreTest() : store() {}
+
+ virtual ~DirectTensorStoreTest() {
+ store.clearHoldLists();
+ }
+
+ void expect_tensor(const Tensor* exp, EntryRef ref) {
+ const auto* act = store.get_tensor(ref);
+ ASSERT_TRUE(act);
+ EXPECT_EQ(exp, act);
+ }
+};
+
+TEST_F(DirectTensorStoreTest, can_set_and_get_tensor)
+{
+ auto t = make_tensor(5);
+ auto* exp = t.get();
+ auto ref = store.set_tensor(std::move(t));
+ expect_tensor(exp, ref);
+}
+
+TEST_F(DirectTensorStoreTest, invalid_ref_returns_nullptr)
+{
+ const auto* t = store.get_tensor(EntryRef());
+ EXPECT_FALSE(t);
+}
+
+TEST_F(DirectTensorStoreTest, hold_adds_entry_to_hold_list)
+{
+ auto ref = store.set_tensor(make_tensor(5));
+ auto mem_1 = store.getMemoryUsage();
+ store.holdTensor(ref);
+ auto mem_2 = store.getMemoryUsage();
+ EXPECT_GT(mem_2.allocatedBytesOnHold(), mem_1.allocatedBytesOnHold());
+}
+
+TEST_F(DirectTensorStoreTest, move_allocates_new_entry_and_puts_old_entry_on_hold)
+{
+ auto t = make_tensor(5);
+ auto* exp = t.get();
+ auto ref_1 = store.set_tensor(std::move(t));
+ auto mem_1 = store.getMemoryUsage();
+
+ auto ref_2 = store.move(ref_1);
+ auto mem_2 = store.getMemoryUsage();
+ EXPECT_NE(ref_1, ref_2);
+ expect_tensor(exp, ref_1);
+ expect_tensor(exp, ref_2);
+ EXPECT_GT(mem_2.allocatedBytesOnHold(), mem_1.allocatedBytesOnHold());
+}
+
+GTEST_MAIN_RUN_ALL_TESTS()
+