summaryrefslogtreecommitdiffstats
path: root/eval/src/tests
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2018-02-05 13:20:32 +0000
committerHåvard Pettersen <havardpe@oath.com>2018-02-05 13:20:32 +0000
commit0053375eecf11d35bf8aeccd7ecbd28a4adf9292 (patch)
tree88067e05d593251c90ea699c40c361b61931c85f /eval/src/tests
parent5a77206ad7fbc7b32ff6988b29ac7fc6f9438b8c (diff)
allow serializing dense tensor views
needed for fall-back to reference implementation using on-the-fly generated dense tensors that are not of the 'DenseTensor' class.
Diffstat (limited to 'eval/src/tests')
-rw-r--r--eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp4
-rw-r--r--eval/src/tests/tensor/dense_tensor_builder/dense_tensor_builder_test.cpp11
2 files changed, 12 insertions, 3 deletions
diff --git a/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp b/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp
index abf51d57b9a..71bbacc7806 100644
--- a/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp
+++ b/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp
@@ -31,10 +31,10 @@ makeTensor(size_t numCells, double cellBias)
double
calcDotProduct(const DenseTensor &lhs, const DenseTensor &rhs)
{
- size_t numCells = std::min(lhs.cells().size(), rhs.cells().size());
+ size_t numCells = std::min(lhs.cellsRef().size(), rhs.cellsRef().size());
double result = 0;
for (size_t i = 0; i < numCells; ++i) {
- result += (lhs.cells()[i] * rhs.cells()[i]);
+ result += (lhs.cellsRef()[i] * rhs.cellsRef()[i]);
}
return result;
}
diff --git a/eval/src/tests/tensor/dense_tensor_builder/dense_tensor_builder_test.cpp b/eval/src/tests/tensor/dense_tensor_builder/dense_tensor_builder_test.cpp
index 61efdbe6d22..ae6166f9d24 100644
--- a/eval/src/tests/tensor/dense_tensor_builder/dense_tensor_builder_test.cpp
+++ b/eval/src/tests/tensor/dense_tensor_builder/dense_tensor_builder_test.cpp
@@ -10,6 +10,15 @@ using vespalib::IllegalArgumentException;
using Builder = DenseTensorBuilder;
using vespalib::eval::TensorSpec;
using vespalib::eval::ValueType;
+using vespalib::ConstArrayRef;
+
+template <typename T> std::vector<T> make_vector(const ConstArrayRef<T> &ref) {
+ std::vector<T> vec;
+ for (const T &t: ref) {
+ vec.push_back(t);
+ }
+ return vec;
+}
void
assertTensor(const std::vector<ValueType::Dimension> &expDims,
@@ -18,7 +27,7 @@ assertTensor(const std::vector<ValueType::Dimension> &expDims,
{
const DenseTensor &realTensor = dynamic_cast<const DenseTensor &>(tensor);
EXPECT_EQUAL(ValueType::tensor_type(expDims), realTensor.type());
- EXPECT_EQUAL(expCells, realTensor.cells());
+ EXPECT_EQUAL(expCells, make_vector(realTensor.cellsRef()));
}
void