summaryrefslogtreecommitdiffstats
path: root/eval/src/tests/instruction/generic_concat/generic_concat_test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'eval/src/tests/instruction/generic_concat/generic_concat_test.cpp')
-rw-r--r--eval/src/tests/instruction/generic_concat/generic_concat_test.cpp64
1 files changed, 4 insertions, 60 deletions
diff --git a/eval/src/tests/instruction/generic_concat/generic_concat_test.cpp b/eval/src/tests/instruction/generic_concat/generic_concat_test.cpp
index aaea8fdcb28..c59d9783648 100644
--- a/eval/src/tests/instruction/generic_concat/generic_concat_test.cpp
+++ b/eval/src/tests/instruction/generic_concat/generic_concat_test.cpp
@@ -8,6 +8,7 @@
#include <vespa/eval/eval/value_codec.h>
#include <vespa/eval/instruction/generic_concat.h>
#include <vespa/eval/eval/interpreted_function.h>
+#include <vespa/eval/eval/test/reference_operations.h>
#include <vespa/eval/eval/test/tensor_model.hpp>
#include <vespa/vespalib/util/stringfmt.h>
#include <vespa/vespalib/gtest/gtest.h>
@@ -64,63 +65,6 @@ TensorSpec perform_simpletensor_concat(const TensorSpec &a, const TensorSpec &b,
return SimpleTensorEngine::ref().to_spec(*out);
}
-bool concat_address(const TensorSpec::Address &me, const TensorSpec::Address &other,
- const std::string &concat_dim, size_t my_offset,
- TensorSpec::Address &my_out, TensorSpec::Address &other_out)
-{
- my_out.insert_or_assign(concat_dim, my_offset);
- for (const auto &my_dim: me) {
- const auto & name = my_dim.first;
- const auto & label = my_dim.second;
- if (name == concat_dim) {
- my_out.insert_or_assign(name, label.index + my_offset);
- } else {
- auto pos = other.find(name);
- if ((pos == other.end()) || (pos->second == label)) {
- my_out.insert_or_assign(name, label);
- other_out.insert_or_assign(name, label);
- } else {
- return false;
- }
- }
- }
- return true;
-}
-
-bool concat_addresses(const TensorSpec::Address &a, const TensorSpec::Address &b,
- const std::string &concat_dim, size_t b_offset,
- TensorSpec::Address &a_out, TensorSpec::Address &b_out)
-{
- return concat_address(a, b, concat_dim, 0, a_out, b_out) &&
- concat_address(b, a, concat_dim, b_offset, b_out, a_out);
-}
-
-TensorSpec reference_concat(const TensorSpec &a, const TensorSpec &b, const std::string &concat_dim) {
- ValueType a_type = ValueType::from_spec(a.type());
- ValueType b_type = ValueType::from_spec(b.type());
- ValueType res_type = ValueType::concat(a_type, b_type, concat_dim);
- EXPECT_FALSE(res_type.is_error());
- size_t b_offset = 1;
- size_t concat_dim_index = a_type.dimension_index(concat_dim);
- if (concat_dim_index != ValueType::Dimension::npos) {
- const auto &dim = a_type.dimensions()[concat_dim_index];
- EXPECT_TRUE(dim.is_indexed());
- b_offset = dim.size;
- }
- TensorSpec result(res_type.to_spec());
- for (const auto &cell_a: a.cells()) {
- for (const auto &cell_b: b.cells()) {
- TensorSpec::Address addr_a;
- TensorSpec::Address addr_b;
- if (concat_addresses(cell_a.first, cell_b.first, concat_dim, b_offset, addr_a, addr_b)) {
- result.add(addr_a, cell_a.second);
- result.add(addr_b, cell_b.second);
- }
- }
- }
- return result;
-}
-
TensorSpec perform_generic_concat(const TensorSpec &a, const TensorSpec &b,
const std::string &concat_dim, const ValueBuilderFactory &factory)
{
@@ -138,7 +82,7 @@ TEST(GenericConcatTest, generic_reference_concat_works) {
const TensorSpec lhs = spec(concat_layouts[i], N());
const TensorSpec rhs = spec(concat_layouts[i + 1], Div16(N()));
SCOPED_TRACE(fmt("\n===\nin LHS: %s\nin RHS: %s\n===\n", lhs.to_string().c_str(), rhs.to_string().c_str()));
- auto actual = reference_concat(lhs, rhs, "y");
+ auto actual = ReferenceOperations::concat(lhs, rhs, "y");
auto expect = perform_simpletensor_concat(lhs, rhs, "y");
EXPECT_EQ(actual, expect);
}
@@ -151,7 +95,7 @@ void test_generic_concat_with(const ValueBuilderFactory &factory) {
const TensorSpec rhs = spec(concat_layouts[i + 1], Div16(N()));
SCOPED_TRACE(fmt("\n===\nin LHS: %s\nin RHS: %s\n===\n", lhs.to_string().c_str(), rhs.to_string().c_str()));
auto actual = perform_generic_concat(lhs, rhs, "y", factory);
- auto expect = reference_concat(lhs, rhs, "y");
+ auto expect = ReferenceOperations::concat(lhs, rhs, "y");
EXPECT_EQ(actual, expect);
}
}
@@ -202,7 +146,7 @@ TEST(GenericConcatTest, immediate_generic_concat_works) {
const TensorSpec rhs = spec(concat_layouts[i + 1], Div16(N()));
SCOPED_TRACE(fmt("\n===\nin LHS: %s\nin RHS: %s\n===\n", lhs.to_string().c_str(), rhs.to_string().c_str()));
auto actual = immediate_generic_concat(lhs, rhs, "y");
- auto expect = reference_concat(lhs, rhs, "y");
+ auto expect = ReferenceOperations::concat(lhs, rhs, "y");
EXPECT_EQ(actual, expect);
}
}