summaryrefslogtreecommitdiffstats
path: root/eval/src/tests/tensor/direct_dense_tensor_builder/direct_dense_tensor_builder_test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'eval/src/tests/tensor/direct_dense_tensor_builder/direct_dense_tensor_builder_test.cpp')
-rw-r--r--eval/src/tests/tensor/direct_dense_tensor_builder/direct_dense_tensor_builder_test.cpp61
1 files changed, 38 insertions, 23 deletions
diff --git a/eval/src/tests/tensor/direct_dense_tensor_builder/direct_dense_tensor_builder_test.cpp b/eval/src/tests/tensor/direct_dense_tensor_builder/direct_dense_tensor_builder_test.cpp
index a0b7b60e2da..3f4641ed2ee 100644
--- a/eval/src/tests/tensor/direct_dense_tensor_builder/direct_dense_tensor_builder_test.cpp
+++ b/eval/src/tests/tensor/direct_dense_tensor_builder/direct_dense_tensor_builder_test.cpp
@@ -2,31 +2,36 @@
#include <vespa/vespalib/test/insertion_operators.h>
#include <vespa/vespalib/testkit/test_kit.h>
-#include <vespa/eval/tensor/dense/direct_dense_tensor_builder.h>
+#include <vespa/eval/tensor/dense/typed_dense_tensor_builder.h>
#include <vespa/vespalib/util/exceptions.h>
using namespace vespalib::tensor;
using vespalib::IllegalArgumentException;
-using Builder = DirectDenseTensorBuilder;
+using BuilderDbl = TypedDenseTensorBuilder<double>;
+using BuilderFlt = TypedDenseTensorBuilder<float>;
using vespalib::eval::TensorSpec;
using vespalib::eval::ValueType;
using vespalib::ConstArrayRef;
-template <typename T> std::vector<T> make_vector(const ConstArrayRef<T> &ref) {
- std::vector<T> vec;
- for (const T &t: ref) {
- vec.push_back(t);
+struct CallMakeVector {
+ template <typename T>
+ static std::vector<double> call(const ConstArrayRef<T> &ref) {
+ std::vector<double> result;
+ result.reserve(ref.size());
+ for (T v : ref) {
+ result.push_back(v);
+ }
+ return result;
}
- return vec;
-}
+};
void assertTensor(const vespalib::string &type_spec,
- const DenseTensor::Cells &expCells,
+ const std::vector<double> &expCells,
const Tensor &tensor)
{
- const DenseTensor &realTensor = dynamic_cast<const DenseTensor &>(tensor);
+ const DenseTensorView &realTensor = dynamic_cast<const DenseTensorView &>(tensor);
EXPECT_EQUAL(ValueType::from_spec(type_spec), realTensor.type());
- EXPECT_EQUAL(expCells, make_vector(realTensor.cellsRef()));
+ EXPECT_EQUAL(expCells, dispatch_1<CallMakeVector>(realTensor.cellsRef()));
}
void assertTensorSpec(const TensorSpec &expSpec, const Tensor &tensor) {
@@ -35,7 +40,7 @@ void assertTensorSpec(const TensorSpec &expSpec, const Tensor &tensor) {
}
Tensor::UP build1DTensor() {
- Builder builder(ValueType::from_spec("tensor(x[3])"));
+ BuilderDbl builder(ValueType::from_spec("tensor(x[3])"));
builder.insertCell(0, 10);
builder.insertCell(1, 11);
builder.insertCell(2, 12);
@@ -55,7 +60,7 @@ TEST("require that 1d tensor can be converted to tensor spec") {
}
Tensor::UP build2DTensor() {
- Builder builder(ValueType::from_spec("tensor(x[3],y[2])"));
+ BuilderDbl builder(ValueType::from_spec("tensor(x[3],y[2])"));
builder.insertCell({0, 0}, 10);
builder.insertCell({0, 1}, 11);
builder.insertCell({1, 0}, 12);
@@ -81,7 +86,7 @@ TEST("require that 2d tensor can be converted to tensor spec") {
}
TEST("require that 3d tensor can be constructed") {
- Builder builder(ValueType::from_spec("tensor(x[3],y[2],z[2])"));
+ BuilderDbl builder(ValueType::from_spec("tensor(x[3],y[2],z[2])"));
builder.insertCell({0, 0, 0}, 10);
builder.insertCell({0, 0, 1}, 11);
builder.insertCell({0, 1, 0}, 12);
@@ -99,16 +104,26 @@ TEST("require that 3d tensor can be constructed") {
*builder.build());
}
+TEST("require that 2d tensor with float cells can be constructed") {
+ BuilderFlt builder(ValueType::from_spec("tensor<float>(x[3],y[2])"));
+ builder.insertCell({0, 1}, 2.5);
+ builder.insertCell({1, 0}, 1.5);
+ builder.insertCell({2, 0}, -0.25);
+ builder.insertCell({2, 1}, 0.75);
+ assertTensor("tensor<float>(x[3],y[2])", {0,2.5,1.5,0,-0.25,0.75},
+ *builder.build());
+}
+
TEST("require that cells get default value 0 if not specified") {
- Builder builder(ValueType::from_spec("tensor(x[3])"));
+ BuilderDbl builder(ValueType::from_spec("tensor(x[3])"));
builder.insertCell(1, 11);
assertTensor("tensor(x[3])", {0,11,0},
*builder.build());
}
-void assertTensorCell(const DenseTensor::Address &expAddress,
+void assertTensorCell(const DenseTensorView::Address &expAddress,
double expCell,
- const DenseTensor::CellsIterator &itr)
+ const DenseTensorView::CellsIterator &itr)
{
EXPECT_TRUE(itr.valid());
EXPECT_EQUAL(expAddress, itr.address());
@@ -118,14 +133,14 @@ void assertTensorCell(const DenseTensor::Address &expAddress,
TEST("require that dense tensor cells iterator works for 1d tensor") {
Tensor::UP tensor;
{
- Builder builder(ValueType::from_spec("tensor(x[2])"));
+ BuilderDbl builder(ValueType::from_spec("tensor(x[2])"));
builder.insertCell(0, 2);
builder.insertCell(1, 3);
tensor = builder.build();
}
- const DenseTensor &denseTensor = dynamic_cast<const DenseTensor &>(*tensor);
- DenseTensor::CellsIterator itr = denseTensor.cellsIterator();
+ const DenseTensorView &denseTensor = dynamic_cast<const DenseTensorView &>(*tensor);
+ DenseTensorView::CellsIterator itr = denseTensor.cellsIterator();
assertTensorCell({0}, 2, itr);
itr.next();
@@ -137,7 +152,7 @@ TEST("require that dense tensor cells iterator works for 1d tensor") {
TEST("require that dense tensor cells iterator works for 2d tensor") {
Tensor::UP tensor;
{
- Builder builder(ValueType::from_spec("tensor(x[2],y[2])"));
+ BuilderDbl builder(ValueType::from_spec("tensor(x[2],y[2])"));
builder.insertCell({0, 0}, 2);
builder.insertCell({0, 1}, 3);
builder.insertCell({1, 0}, 5);
@@ -145,8 +160,8 @@ TEST("require that dense tensor cells iterator works for 2d tensor") {
tensor = builder.build();
}
- const DenseTensor &denseTensor = dynamic_cast<const DenseTensor &>(*tensor);
- DenseTensor::CellsIterator itr = denseTensor.cellsIterator();
+ const DenseTensorView &denseTensor = dynamic_cast<const DenseTensorView &>(*tensor);
+ DenseTensorView::CellsIterator itr = denseTensor.cellsIterator();
assertTensorCell({0,0}, 2, itr);
itr.next();