summaryrefslogtreecommitdiffstats
path: root/eval/src/tests/instruction/generic_merge
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2020-10-22 07:23:00 +0000
committerArne Juul <arnej@verizonmedia.com>2020-10-22 07:45:50 +0000
commitcb62cd0e86a6bf708efa8212930c6ce99f5e6aac (patch)
treed14f4978b509c2736c63a0cc90c9a2dc574e3f74 /eval/src/tests/instruction/generic_merge
parenteeed85a0c98c88c1de65ff7e821025f29fc0347e (diff)
add unit tests with FastValueBuilderFactory also
Diffstat (limited to 'eval/src/tests/instruction/generic_merge')
-rw-r--r--eval/src/tests/instruction/generic_merge/generic_merge_test.cpp16
1 files changed, 12 insertions, 4 deletions
diff --git a/eval/src/tests/instruction/generic_merge/generic_merge_test.cpp b/eval/src/tests/instruction/generic_merge/generic_merge_test.cpp
index fd9b8513acb..5166ef6ccc9 100644
--- a/eval/src/tests/instruction/generic_merge/generic_merge_test.cpp
+++ b/eval/src/tests/instruction/generic_merge/generic_merge_test.cpp
@@ -1,6 +1,7 @@
// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include <vespa/eval/eval/simple_value.h>
+#include <vespa/eval/eval/fast_value.h>
#include <vespa/eval/eval/value_codec.h>
#include <vespa/eval/instruction/generic_merge.h>
#include <vespa/eval/eval/interpreted_function.h>
@@ -55,9 +56,8 @@ TensorSpec reference_merge(const TensorSpec &a, const TensorSpec &b, join_fun_t
return result;
}
-TensorSpec perform_generic_merge(const TensorSpec &a, const TensorSpec &b, join_fun_t fun) {
+TensorSpec perform_generic_merge(const TensorSpec &a, const TensorSpec &b, join_fun_t fun, const ValueBuilderFactory &factory) {
Stash stash;
- const auto &factory = SimpleValueBuilderFactory::get();
auto lhs = value_from_spec(a, factory);
auto rhs = value_from_spec(b, factory);
auto my_op = GenericMerge::make_instruction(lhs->type(), rhs->type(), fun, factory, stash);
@@ -65,7 +65,7 @@ TensorSpec perform_generic_merge(const TensorSpec &a, const TensorSpec &b, join_
return spec_from_value(single.eval(std::vector<Value::CREF>({*lhs, *rhs})));
}
-TEST(GenericMergeTest, generic_merge_works_for_simple_values) {
+void test_generic_merge_with(const ValueBuilderFactory &factory) {
ASSERT_TRUE((merge_layouts.size() % 2) == 0);
for (size_t i = 0; i < merge_layouts.size(); i += 2) {
TensorSpec lhs = spec(merge_layouts[i], N());
@@ -73,12 +73,20 @@ TEST(GenericMergeTest, generic_merge_works_for_simple_values) {
SCOPED_TRACE(fmt("\n===\nLHS: %s\nRHS: %s\n===\n", lhs.to_string().c_str(), rhs.to_string().c_str()));
for (auto fun: {operation::Add::f, operation::Mul::f, operation::Sub::f, operation::Max::f}) {
auto expect = reference_merge(lhs, rhs, fun);
- auto actual = perform_generic_merge(lhs, rhs, fun);
+ auto actual = perform_generic_merge(lhs, rhs, fun, factory);
EXPECT_EQ(actual, expect);
}
}
}
+TEST(GenericMergeTest, generic_merge_works_for_simple_values) {
+ test_generic_merge_with(SimpleValueBuilderFactory::get());
+}
+
+TEST(GenericMergeTest, generic_merge_works_for_fast_values) {
+ test_generic_merge_with(FastValueBuilderFactory::get());
+}
+
TensorSpec immediate_generic_merge(const TensorSpec &a, const TensorSpec &b, join_fun_t fun) {
const auto &factory = SimpleValueBuilderFactory::get();
auto lhs = value_from_spec(a, factory);