summaryrefslogtreecommitdiffstats
path: root/searchlib/src/tests/tensor/direct_tensor_store
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2020-10-13 14:11:28 +0000
committerArne Juul <arnej@verizonmedia.com>2020-10-15 08:18:59 +0000
commit08393e9e14635f1c6a6c84650c25023a0db7ed0b (patch)
tree48aae1605140fc6ff7d571084f345d33a3189c62 /searchlib/src/tests/tensor/direct_tensor_store
parent61eaea251e8cacd320ac10754ffd1513d8638043 (diff)
handle both engine- and factory-based tensors
* use EngineOrFactory::get() instead of DefaultTensorEngine::ref() * avoid direct use of DenseTensorView etc where possible * use eval::Value instead of tensor::Tensor where possible
Diffstat (limited to 'searchlib/src/tests/tensor/direct_tensor_store')
-rw-r--r--searchlib/src/tests/tensor/direct_tensor_store/direct_tensor_store_test.cpp22
1 files changed, 10 insertions, 12 deletions
diff --git a/searchlib/src/tests/tensor/direct_tensor_store/direct_tensor_store_test.cpp b/searchlib/src/tests/tensor/direct_tensor_store/direct_tensor_store_test.cpp
index 5a815a96dfb..4eb2c935234 100644
--- a/searchlib/src/tests/tensor/direct_tensor_store/direct_tensor_store_test.cpp
+++ b/searchlib/src/tests/tensor/direct_tensor_store/direct_tensor_store_test.cpp
@@ -2,30 +2,28 @@
#include <vespa/searchlib/tensor/direct_tensor_store.h>
#include <vespa/vespalib/gtest/gtest.h>
-#include <vespa/eval/tensor/default_tensor_engine.h>
-#include <vespa/eval/tensor/tensor.h>
+#include <vespa/eval/eval/engine_or_factory.h>
+#include <vespa/eval/eval/value.h>
#include <vespa/vespalib/datastore/datastore.hpp>
using namespace search::tensor;
using vespalib::datastore::EntryRef;
+using vespalib::eval::EngineOrFactory;
using vespalib::eval::TensorSpec;
-using vespalib::tensor::DefaultTensorEngine;
-using vespalib::tensor::Tensor;
+using vespalib::eval::Value;
vespalib::string tensor_spec("tensor(x{})");
-Tensor::UP
+Value::UP
make_tensor(const TensorSpec& spec)
{
- auto value = DefaultTensorEngine::ref().from_spec(spec);
- auto* tensor = dynamic_cast<Tensor*>(value.get());
- assert(tensor != nullptr);
- value.release();
- return Tensor::UP(tensor);
+ auto value = EngineOrFactory::get().from_spec(spec);
+ assert(value->is_tensor());
+ return value;
}
-Tensor::UP
+Value::UP
make_tensor(double value)
{
return make_tensor(TensorSpec(tensor_spec).add({{"x", "a"}}, value));
@@ -41,7 +39,7 @@ public:
store.clearHoldLists();
}
- void expect_tensor(const Tensor* exp, EntryRef ref) {
+ void expect_tensor(const Value* exp, EntryRef ref) {
const auto* act = store.get_tensor(ref);
ASSERT_TRUE(act);
EXPECT_EQ(exp, act);