diff options
Diffstat (limited to 'eval/src/tests/tensor/tensor_modify_operation/tensor_modify_operation_test.cpp')
-rw-r--r-- | eval/src/tests/tensor/tensor_modify_operation/tensor_modify_operation_test.cpp | 51 |
1 files changed, 36 insertions, 15 deletions
diff --git a/eval/src/tests/tensor/tensor_modify_operation/tensor_modify_operation_test.cpp b/eval/src/tests/tensor/tensor_modify_operation/tensor_modify_operation_test.cpp index ff59b28d60c..31f17b3eed2 100644 --- a/eval/src/tests/tensor/tensor_modify_operation/tensor_modify_operation_test.cpp +++ b/eval/src/tests/tensor/tensor_modify_operation/tensor_modify_operation_test.cpp @@ -1,5 +1,6 @@ // Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#include <vespa/eval/eval/operation.h> #include <vespa/eval/eval/tensor_spec.h> #include <vespa/eval/tensor/cell_values.h> #include <vespa/eval/tensor/default_tensor_engine.h> @@ -13,16 +14,6 @@ using vespalib::eval::TensorSpec; using vespalib::tensor::test::makeTensor; using namespace vespalib::tensor; -namespace { - -double -replace(double, double b) -{ - return b; -} - -} - void checkUpdate(const TensorSpec &source, const TensorSpec &update, const TensorSpec &expect) { @@ -30,7 +21,7 @@ checkUpdate(const TensorSpec &source, const TensorSpec &update, const TensorSpec auto updateTensor = makeTensor<SparseTensor>(update); const CellValues cellValues(*updateTensor); - auto actualTensor = sourceTensor->modify(replace, cellValues); + auto actualTensor = sourceTensor->modify(vespalib::eval::operation::Add::f, cellValues); auto actual = actualTensor->toSpec(); auto expectTensor = makeTensor<Tensor>(expect); auto expectPadded = expectTensor->toSpec(); @@ -45,7 +36,7 @@ TEST(TensorModifyTest, sparse_tensors_can_be_modified) TensorSpec("tensor(x{},y{})") .add({{"x","8"},{"y","9"}}, 2), TensorSpec("tensor(x{},y{})") - .add({{"x","8"},{"y","9"}}, 2) + .add({{"x","8"},{"y","9"}}, 13) .add({{"x","9"},{"y","9"}}, 11)); } @@ -57,10 +48,27 @@ TEST(TensorModifyTest, dense_tensors_can_be_modified) TensorSpec("tensor(x{},y{})") .add({{"x","8"},{"y","9"}}, 2), TensorSpec("tensor(x[10],y[10])") - .add({{"x",8},{"y",9}}, 2) + .add({{"x",8},{"y",9}}, 13) .add({{"x",9},{"y",9}}, 11)); } +TEST(TensorModifyTest, mixed_tensors_can_be_modified) +{ + checkUpdate(TensorSpec("tensor(x{},y[2])") + .add({{"x","a"},{"y",0}}, 2) + .add({{"x","a"},{"y",1}}, 3) + .add({{"x","b"},{"y",0}}, 4) + .add({{"x","b"},{"y",1}}, 5), + TensorSpec("tensor(x{},y{})") + .add({{"x","a"},{"y","0"}}, 6) + .add({{"x","b"},{"y","1"}}, 7), + TensorSpec("tensor(x{},y[2])") + .add({{"x","a"},{"y",0}}, 8) + .add({{"x","a"},{"y",1}}, 3) + .add({{"x","b"},{"y",0}}, 4) + .add({{"x","b"},{"y",1}}, 12)); +} + TEST(TensorModifyTest, sparse_tensors_ignore_updates_to_missing_cells) { checkUpdate(TensorSpec("tensor(x{},y{})") @@ -70,7 +78,7 @@ TEST(TensorModifyTest, sparse_tensors_ignore_updates_to_missing_cells) .add({{"x","7"},{"y","9"}}, 2) .add({{"x","8"},{"y","9"}}, 2), TensorSpec("tensor(x{},y{})") - .add({{"x","8"},{"y","9"}}, 2) + .add({{"x","8"},{"y","9"}}, 13) .add({{"x","9"},{"y","9"}}, 11)); } @@ -83,8 +91,21 @@ TEST(TensorModifyTest, dense_tensors_ignore_updates_to_out_of_range_cells) .add({{"x","8"},{"y","9"}}, 2) .add({{"x","10"},{"y","9"}}, 2), TensorSpec("tensor(x[10],y[10])") - .add({{"x",8},{"y",9}}, 2) + .add({{"x",8},{"y",9}}, 13) .add({{"x",9},{"y",9}}, 11)); } +TEST(TensorModifyTest, mixed_tensors_ignore_updates_to_missing_or_out_of_range_cells) +{ + checkUpdate(TensorSpec("tensor(x{},y[2])") + .add({{"x","a"},{"y",0}}, 2) + .add({{"x","a"},{"y",1}}, 3), + TensorSpec("tensor(x{},y{})") + .add({{"x","a"},{"y","2"}}, 4) + .add({{"x","c"},{"y","0"}}, 5), + TensorSpec("tensor(x{},y[2])") + .add({{"x","a"},{"y",0}}, 2) + .add({{"x","a"},{"y",1}}, 3)); +} + GTEST_MAIN_RUN_ALL_TESTS |