diff options
author | Arne Juul <arnej@verizonmedia.com> | 2020-11-30 12:51:04 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2020-11-30 12:55:13 +0000 |
commit | 450c0f9bd7eee971c8f63824cbb4f2b7a3a4d1c4 (patch) | |
tree | 21a42692a2cccfc7f995ef3285c041a6beb0486d /eval/src/tests/streamed | |
parent | 59af1d07ae7f4692cc75bfcad62648ca3e72e9df (diff) |
compare with reference operation
Diffstat (limited to 'eval/src/tests/streamed')
-rw-r--r-- | eval/src/tests/streamed/value/streamed_value_test.cpp | 16 |
1 files changed, 4 insertions, 12 deletions
diff --git a/eval/src/tests/streamed/value/streamed_value_test.cpp b/eval/src/tests/streamed/value/streamed_value_test.cpp index 3de6ba0fb63..05d6e20451c 100644 --- a/eval/src/tests/streamed/value/streamed_value_test.cpp +++ b/eval/src/tests/streamed/value/streamed_value_test.cpp @@ -1,6 +1,7 @@ // Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include <vespa/eval/streamed/streamed_value_builder_factory.h> +#include <vespa/eval/eval/test/reference_operations.h> #include <vespa/eval/eval/value_codec.h> #include <vespa/eval/instruction/generic_join.h> #include <vespa/eval/eval/interpreted_function.h> @@ -59,16 +60,7 @@ std::vector<Layout> join_layouts = { float_cells({x({"a","b","c"}),y(5)}), float_cells({y(5),z({"i","j","k","l"})}) }; -TensorSpec simple_tensor_join(const TensorSpec &a, const TensorSpec &b, join_fun_t function) { - Stash stash; - const auto &engine = SimpleTensorEngine::ref(); - auto lhs = engine.from_spec(a); - auto rhs = engine.from_spec(b); - const auto &result = engine.join(*lhs, *rhs, function, stash); - return engine.to_spec(result); -} - -TensorSpec streamed_value_new_join(const TensorSpec &a, const TensorSpec &b, join_fun_t function) { +TensorSpec streamed_value_join(const TensorSpec &a, const TensorSpec &b, join_fun_t function) { Stash stash; const auto &factory = StreamedValueBuilderFactory::get(); auto lhs = value_from_spec(a, factory); @@ -126,8 +118,8 @@ TEST(StreamedValueTest, new_generic_join_works_for_streamed_values) { TensorSpec rhs = spec(join_layouts[i + 1], Div16(N())); for (auto fun: {operation::Add::f, operation::Sub::f, operation::Mul::f, operation::Max::f}) { SCOPED_TRACE(fmt("\n===\nLHS: %s\nRHS: %s\n===\n", lhs.to_string().c_str(), rhs.to_string().c_str())); - auto expect = simple_tensor_join(lhs, rhs, fun); - auto actual = streamed_value_new_join(lhs, rhs, fun); + auto expect = ReferenceOperations::join(lhs, rhs, fun); + auto actual = streamed_value_join(lhs, rhs, fun); EXPECT_EQ(actual, expect); } } |