aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/tests/streamed
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2020-11-30 12:51:04 +0000
committerArne Juul <arnej@verizonmedia.com>2020-11-30 12:55:13 +0000
commit450c0f9bd7eee971c8f63824cbb4f2b7a3a4d1c4 (patch)
tree21a42692a2cccfc7f995ef3285c041a6beb0486d /eval/src/tests/streamed
parent59af1d07ae7f4692cc75bfcad62648ca3e72e9df (diff)
compare with reference operation
Diffstat (limited to 'eval/src/tests/streamed')
-rw-r--r--eval/src/tests/streamed/value/streamed_value_test.cpp16
1 files changed, 4 insertions, 12 deletions
diff --git a/eval/src/tests/streamed/value/streamed_value_test.cpp b/eval/src/tests/streamed/value/streamed_value_test.cpp
index 3de6ba0fb63..05d6e20451c 100644
--- a/eval/src/tests/streamed/value/streamed_value_test.cpp
+++ b/eval/src/tests/streamed/value/streamed_value_test.cpp
@@ -1,6 +1,7 @@
// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include <vespa/eval/streamed/streamed_value_builder_factory.h>
+#include <vespa/eval/eval/test/reference_operations.h>
#include <vespa/eval/eval/value_codec.h>
#include <vespa/eval/instruction/generic_join.h>
#include <vespa/eval/eval/interpreted_function.h>
@@ -59,16 +60,7 @@ std::vector<Layout> join_layouts = {
float_cells({x({"a","b","c"}),y(5)}), float_cells({y(5),z({"i","j","k","l"})})
};
-TensorSpec simple_tensor_join(const TensorSpec &a, const TensorSpec &b, join_fun_t function) {
- Stash stash;
- const auto &engine = SimpleTensorEngine::ref();
- auto lhs = engine.from_spec(a);
- auto rhs = engine.from_spec(b);
- const auto &result = engine.join(*lhs, *rhs, function, stash);
- return engine.to_spec(result);
-}
-
-TensorSpec streamed_value_new_join(const TensorSpec &a, const TensorSpec &b, join_fun_t function) {
+TensorSpec streamed_value_join(const TensorSpec &a, const TensorSpec &b, join_fun_t function) {
Stash stash;
const auto &factory = StreamedValueBuilderFactory::get();
auto lhs = value_from_spec(a, factory);
@@ -126,8 +118,8 @@ TEST(StreamedValueTest, new_generic_join_works_for_streamed_values) {
TensorSpec rhs = spec(join_layouts[i + 1], Div16(N()));
for (auto fun: {operation::Add::f, operation::Sub::f, operation::Mul::f, operation::Max::f}) {
SCOPED_TRACE(fmt("\n===\nLHS: %s\nRHS: %s\n===\n", lhs.to_string().c_str(), rhs.to_string().c_str()));
- auto expect = simple_tensor_join(lhs, rhs, fun);
- auto actual = streamed_value_new_join(lhs, rhs, fun);
+ auto expect = ReferenceOperations::join(lhs, rhs, fun);
+ auto actual = streamed_value_join(lhs, rhs, fun);
EXPECT_EQ(actual, expect);
}
}