diff options
Diffstat (limited to 'eval/src/tests/instruction/generic_concat/generic_concat_test.cpp')
-rw-r--r-- | eval/src/tests/instruction/generic_concat/generic_concat_test.cpp | 64 |
1 files changed, 4 insertions, 60 deletions
diff --git a/eval/src/tests/instruction/generic_concat/generic_concat_test.cpp b/eval/src/tests/instruction/generic_concat/generic_concat_test.cpp index aaea8fdcb28..c59d9783648 100644 --- a/eval/src/tests/instruction/generic_concat/generic_concat_test.cpp +++ b/eval/src/tests/instruction/generic_concat/generic_concat_test.cpp @@ -8,6 +8,7 @@ #include <vespa/eval/eval/value_codec.h> #include <vespa/eval/instruction/generic_concat.h> #include <vespa/eval/eval/interpreted_function.h> +#include <vespa/eval/eval/test/reference_operations.h> #include <vespa/eval/eval/test/tensor_model.hpp> #include <vespa/vespalib/util/stringfmt.h> #include <vespa/vespalib/gtest/gtest.h> @@ -64,63 +65,6 @@ TensorSpec perform_simpletensor_concat(const TensorSpec &a, const TensorSpec &b, return SimpleTensorEngine::ref().to_spec(*out); } -bool concat_address(const TensorSpec::Address &me, const TensorSpec::Address &other, - const std::string &concat_dim, size_t my_offset, - TensorSpec::Address &my_out, TensorSpec::Address &other_out) -{ - my_out.insert_or_assign(concat_dim, my_offset); - for (const auto &my_dim: me) { - const auto & name = my_dim.first; - const auto & label = my_dim.second; - if (name == concat_dim) { - my_out.insert_or_assign(name, label.index + my_offset); - } else { - auto pos = other.find(name); - if ((pos == other.end()) || (pos->second == label)) { - my_out.insert_or_assign(name, label); - other_out.insert_or_assign(name, label); - } else { - return false; - } - } - } - return true; -} - -bool concat_addresses(const TensorSpec::Address &a, const TensorSpec::Address &b, - const std::string &concat_dim, size_t b_offset, - TensorSpec::Address &a_out, TensorSpec::Address &b_out) -{ - return concat_address(a, b, concat_dim, 0, a_out, b_out) && - concat_address(b, a, concat_dim, b_offset, b_out, a_out); -} - -TensorSpec reference_concat(const TensorSpec &a, const TensorSpec &b, const std::string &concat_dim) { - ValueType a_type = ValueType::from_spec(a.type()); - ValueType b_type = ValueType::from_spec(b.type()); - ValueType res_type = ValueType::concat(a_type, b_type, concat_dim); - EXPECT_FALSE(res_type.is_error()); - size_t b_offset = 1; - size_t concat_dim_index = a_type.dimension_index(concat_dim); - if (concat_dim_index != ValueType::Dimension::npos) { - const auto &dim = a_type.dimensions()[concat_dim_index]; - EXPECT_TRUE(dim.is_indexed()); - b_offset = dim.size; - } - TensorSpec result(res_type.to_spec()); - for (const auto &cell_a: a.cells()) { - for (const auto &cell_b: b.cells()) { - TensorSpec::Address addr_a; - TensorSpec::Address addr_b; - if (concat_addresses(cell_a.first, cell_b.first, concat_dim, b_offset, addr_a, addr_b)) { - result.add(addr_a, cell_a.second); - result.add(addr_b, cell_b.second); - } - } - } - return result; -} - TensorSpec perform_generic_concat(const TensorSpec &a, const TensorSpec &b, const std::string &concat_dim, const ValueBuilderFactory &factory) { @@ -138,7 +82,7 @@ TEST(GenericConcatTest, generic_reference_concat_works) { const TensorSpec lhs = spec(concat_layouts[i], N()); const TensorSpec rhs = spec(concat_layouts[i + 1], Div16(N())); SCOPED_TRACE(fmt("\n===\nin LHS: %s\nin RHS: %s\n===\n", lhs.to_string().c_str(), rhs.to_string().c_str())); - auto actual = reference_concat(lhs, rhs, "y"); + auto actual = ReferenceOperations::concat(lhs, rhs, "y"); auto expect = perform_simpletensor_concat(lhs, rhs, "y"); EXPECT_EQ(actual, expect); } @@ -151,7 +95,7 @@ void test_generic_concat_with(const ValueBuilderFactory &factory) { const TensorSpec rhs = spec(concat_layouts[i + 1], Div16(N())); SCOPED_TRACE(fmt("\n===\nin LHS: %s\nin RHS: %s\n===\n", lhs.to_string().c_str(), rhs.to_string().c_str())); auto actual = perform_generic_concat(lhs, rhs, "y", factory); - auto expect = reference_concat(lhs, rhs, "y"); + auto expect = ReferenceOperations::concat(lhs, rhs, "y"); EXPECT_EQ(actual, expect); } } @@ -202,7 +146,7 @@ TEST(GenericConcatTest, immediate_generic_concat_works) { const TensorSpec rhs = spec(concat_layouts[i + 1], Div16(N())); SCOPED_TRACE(fmt("\n===\nin LHS: %s\nin RHS: %s\n===\n", lhs.to_string().c_str(), rhs.to_string().c_str())); auto actual = immediate_generic_concat(lhs, rhs, "y"); - auto expect = reference_concat(lhs, rhs, "y"); + auto expect = ReferenceOperations::concat(lhs, rhs, "y"); EXPECT_EQ(actual, expect); } } |