diff options
Diffstat (limited to 'eval/src/tests/tensor/direct_dense_tensor_builder/direct_dense_tensor_builder_test.cpp')
-rw-r--r-- | eval/src/tests/tensor/direct_dense_tensor_builder/direct_dense_tensor_builder_test.cpp | 61 |
1 files changed, 38 insertions, 23 deletions
diff --git a/eval/src/tests/tensor/direct_dense_tensor_builder/direct_dense_tensor_builder_test.cpp b/eval/src/tests/tensor/direct_dense_tensor_builder/direct_dense_tensor_builder_test.cpp index a0b7b60e2da..3f4641ed2ee 100644 --- a/eval/src/tests/tensor/direct_dense_tensor_builder/direct_dense_tensor_builder_test.cpp +++ b/eval/src/tests/tensor/direct_dense_tensor_builder/direct_dense_tensor_builder_test.cpp @@ -2,31 +2,36 @@ #include <vespa/vespalib/test/insertion_operators.h> #include <vespa/vespalib/testkit/test_kit.h> -#include <vespa/eval/tensor/dense/direct_dense_tensor_builder.h> +#include <vespa/eval/tensor/dense/typed_dense_tensor_builder.h> #include <vespa/vespalib/util/exceptions.h> using namespace vespalib::tensor; using vespalib::IllegalArgumentException; -using Builder = DirectDenseTensorBuilder; +using BuilderDbl = TypedDenseTensorBuilder<double>; +using BuilderFlt = TypedDenseTensorBuilder<float>; using vespalib::eval::TensorSpec; using vespalib::eval::ValueType; using vespalib::ConstArrayRef; -template <typename T> std::vector<T> make_vector(const ConstArrayRef<T> &ref) { - std::vector<T> vec; - for (const T &t: ref) { - vec.push_back(t); +struct CallMakeVector { + template <typename T> + static std::vector<double> call(const ConstArrayRef<T> &ref) { + std::vector<double> result; + result.reserve(ref.size()); + for (T v : ref) { + result.push_back(v); + } + return result; } - return vec; -} +}; void assertTensor(const vespalib::string &type_spec, - const DenseTensor::Cells &expCells, + const std::vector<double> &expCells, const Tensor &tensor) { - const DenseTensor &realTensor = dynamic_cast<const DenseTensor &>(tensor); + const DenseTensorView &realTensor = dynamic_cast<const DenseTensorView &>(tensor); EXPECT_EQUAL(ValueType::from_spec(type_spec), realTensor.type()); - EXPECT_EQUAL(expCells, make_vector(realTensor.cellsRef())); + EXPECT_EQUAL(expCells, dispatch_1<CallMakeVector>(realTensor.cellsRef())); } void assertTensorSpec(const TensorSpec &expSpec, const Tensor &tensor) { @@ -35,7 +40,7 @@ void assertTensorSpec(const TensorSpec &expSpec, const Tensor &tensor) { } Tensor::UP build1DTensor() { - Builder builder(ValueType::from_spec("tensor(x[3])")); + BuilderDbl builder(ValueType::from_spec("tensor(x[3])")); builder.insertCell(0, 10); builder.insertCell(1, 11); builder.insertCell(2, 12); @@ -55,7 +60,7 @@ TEST("require that 1d tensor can be converted to tensor spec") { } Tensor::UP build2DTensor() { - Builder builder(ValueType::from_spec("tensor(x[3],y[2])")); + BuilderDbl builder(ValueType::from_spec("tensor(x[3],y[2])")); builder.insertCell({0, 0}, 10); builder.insertCell({0, 1}, 11); builder.insertCell({1, 0}, 12); @@ -81,7 +86,7 @@ TEST("require that 2d tensor can be converted to tensor spec") { } TEST("require that 3d tensor can be constructed") { - Builder builder(ValueType::from_spec("tensor(x[3],y[2],z[2])")); + BuilderDbl builder(ValueType::from_spec("tensor(x[3],y[2],z[2])")); builder.insertCell({0, 0, 0}, 10); builder.insertCell({0, 0, 1}, 11); builder.insertCell({0, 1, 0}, 12); @@ -99,16 +104,26 @@ TEST("require that 3d tensor can be constructed") { *builder.build()); } +TEST("require that 2d tensor with float cells can be constructed") { + BuilderFlt builder(ValueType::from_spec("tensor<float>(x[3],y[2])")); + builder.insertCell({0, 1}, 2.5); + builder.insertCell({1, 0}, 1.5); + builder.insertCell({2, 0}, -0.25); + builder.insertCell({2, 1}, 0.75); + assertTensor("tensor<float>(x[3],y[2])", {0,2.5,1.5,0,-0.25,0.75}, + *builder.build()); +} + TEST("require that cells get default value 0 if not specified") { - Builder builder(ValueType::from_spec("tensor(x[3])")); + BuilderDbl builder(ValueType::from_spec("tensor(x[3])")); builder.insertCell(1, 11); assertTensor("tensor(x[3])", {0,11,0}, *builder.build()); } -void assertTensorCell(const DenseTensor::Address &expAddress, +void assertTensorCell(const DenseTensorView::Address &expAddress, double expCell, - const DenseTensor::CellsIterator &itr) + const DenseTensorView::CellsIterator &itr) { EXPECT_TRUE(itr.valid()); EXPECT_EQUAL(expAddress, itr.address()); @@ -118,14 +133,14 @@ void assertTensorCell(const DenseTensor::Address &expAddress, TEST("require that dense tensor cells iterator works for 1d tensor") { Tensor::UP tensor; { - Builder builder(ValueType::from_spec("tensor(x[2])")); + BuilderDbl builder(ValueType::from_spec("tensor(x[2])")); builder.insertCell(0, 2); builder.insertCell(1, 3); tensor = builder.build(); } - const DenseTensor &denseTensor = dynamic_cast<const DenseTensor &>(*tensor); - DenseTensor::CellsIterator itr = denseTensor.cellsIterator(); + const DenseTensorView &denseTensor = dynamic_cast<const DenseTensorView &>(*tensor); + DenseTensorView::CellsIterator itr = denseTensor.cellsIterator(); assertTensorCell({0}, 2, itr); itr.next(); @@ -137,7 +152,7 @@ TEST("require that dense tensor cells iterator works for 1d tensor") { TEST("require that dense tensor cells iterator works for 2d tensor") { Tensor::UP tensor; { - Builder builder(ValueType::from_spec("tensor(x[2],y[2])")); + BuilderDbl builder(ValueType::from_spec("tensor(x[2],y[2])")); builder.insertCell({0, 0}, 2); builder.insertCell({0, 1}, 3); builder.insertCell({1, 0}, 5); @@ -145,8 +160,8 @@ TEST("require that dense tensor cells iterator works for 2d tensor") { tensor = builder.build(); } - const DenseTensor &denseTensor = dynamic_cast<const DenseTensor &>(*tensor); - DenseTensor::CellsIterator itr = denseTensor.cellsIterator(); + const DenseTensorView &denseTensor = dynamic_cast<const DenseTensorView &>(*tensor); + DenseTensorView::CellsIterator itr = denseTensor.cellsIterator(); assertTensorCell({0,0}, 2, itr); itr.next(); |