aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/tests/tensor
diff options
context:
space:
mode:
authorGeir Storli <geirst@verizonmedia.com>2020-08-31 13:56:19 +0000
committerGeir Storli <geirst@verizonmedia.com>2020-08-31 13:58:43 +0000
commite35cb626e2825376f4c88d62d5466bd8c205a8d4 (patch)
treecc6b8737fd9d1ee3433278c41596146063722f69 /searchlib/src/tests/tensor
parent078b0dcd710adc8bd53add3fb482b920d8806a93 (diff)
Implement store for heap allocated tensors.
Diffstat (limited to 'searchlib/src/tests/tensor')
-rw-r--r--searchlib/src/tests/tensor/direct_tensor_store/CMakeLists.txt9
-rw-r--r--searchlib/src/tests/tensor/direct_tensor_store/direct_tensor_store_test.cpp93
2 files changed, 102 insertions, 0 deletions
diff --git a/searchlib/src/tests/tensor/direct_tensor_store/CMakeLists.txt b/searchlib/src/tests/tensor/direct_tensor_store/CMakeLists.txt
new file mode 100644
index 00000000000..14a70f25e3c
--- /dev/null
+++ b/searchlib/src/tests/tensor/direct_tensor_store/CMakeLists.txt
@@ -0,0 +1,9 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+vespa_add_executable(searchlib_direct_tensor_store_test_app TEST
+ SOURCES
+ direct_tensor_store_test.cpp
+ DEPENDS
+ searchlib
+ GTest::GTest
+)
+vespa_add_test(NAME searchlib_direct_tensor_store_test_app COMMAND searchlib_direct_tensor_store_test_app)
diff --git a/searchlib/src/tests/tensor/direct_tensor_store/direct_tensor_store_test.cpp b/searchlib/src/tests/tensor/direct_tensor_store/direct_tensor_store_test.cpp
new file mode 100644
index 00000000000..763d5c0cbd5
--- /dev/null
+++ b/searchlib/src/tests/tensor/direct_tensor_store/direct_tensor_store_test.cpp
@@ -0,0 +1,93 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/searchlib/tensor/direct_tensor_store.h>
+#include <vespa/vespalib/gtest/gtest.h>
+#include <vespa/eval/tensor/default_tensor_engine.h>
+#include <vespa/eval/tensor/tensor.h>
+
+#include <vespa/log/log.h>
+LOG_SETUP("direct_tensor_store_test");
+
+using namespace search::tensor;
+
+using vespalib::datastore::EntryRef;
+using vespalib::eval::TensorSpec;
+using vespalib::tensor::DefaultTensorEngine;
+using vespalib::tensor::Tensor;
+
+vespalib::string tensor_spec("tensor(x{})");
+
+Tensor::UP
+make_tensor(const TensorSpec& spec)
+{
+ auto value = DefaultTensorEngine::ref().from_spec(spec);
+ auto* tensor = dynamic_cast<Tensor*>(value.get());
+ assert(tensor != nullptr);
+ value.release();
+ return Tensor::UP(tensor);
+}
+
+Tensor::UP
+make_tensor(double value)
+{
+ return make_tensor(TensorSpec(tensor_spec).add({{"x", "a"}}, value));
+}
+
+class DirectTensorStoreTest : public ::testing::Test {
+public:
+ DirectTensorStore store;
+
+ DirectTensorStoreTest() : store() {}
+
+ virtual ~DirectTensorStoreTest() {
+ store.clearHoldLists();
+ }
+
+ void expect_tensor(const Tensor* exp, EntryRef ref) {
+ const auto* act = store.get_tensor(ref);
+ ASSERT_TRUE(act);
+ EXPECT_EQ(exp, act);
+ EXPECT_EQ(*exp, *act);
+ }
+};
+
+TEST_F(DirectTensorStoreTest, can_set_and_get_tensor)
+{
+ auto t = make_tensor(5);
+ auto* exp = t.get();
+ auto ref = store.set_tensor(std::move(t));
+ expect_tensor(exp, ref);
+}
+
+TEST_F(DirectTensorStoreTest, invalid_ref_returns_nullptr)
+{
+ const auto* t = store.get_tensor(EntryRef());
+ EXPECT_FALSE(t);
+}
+
+TEST_F(DirectTensorStoreTest, hold_adds_entry_to_hold_list)
+{
+ auto ref = store.set_tensor(make_tensor(5));
+ auto mem_1 = store.getMemoryUsage();
+ store.holdTensor(ref);
+ auto mem_2 = store.getMemoryUsage();
+ EXPECT_GT(mem_2.allocatedBytesOnHold(), mem_1.allocatedBytesOnHold());
+}
+
+TEST_F(DirectTensorStoreTest, move_allocates_new_entry_and_puts_old_entry_on_hold)
+{
+ auto t = make_tensor(5);
+ auto* exp = t.get();
+ auto ref_1 = store.set_tensor(std::move(t));
+ auto mem_1 = store.getMemoryUsage();
+
+ auto ref_2 = store.move(ref_1);
+ auto mem_2 = store.getMemoryUsage();
+ EXPECT_NE(ref_1, ref_2);
+ expect_tensor(exp, ref_1);
+ expect_tensor(exp, ref_2);
+ EXPECT_GT(mem_2.allocatedBytesOnHold(), mem_1.allocatedBytesOnHold());
+}
+
+GTEST_MAIN_RUN_ALL_TESTS()
+