summaryrefslogtreecommitdiffstats
path: root/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp')
-rw-r--r--eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp71
1 files changed, 43 insertions, 28 deletions
diff --git a/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp b/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp
index e1967f012cb..bf9aeead461 100644
--- a/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp
+++ b/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp
@@ -15,6 +15,7 @@
#include <vespa/vespalib/util/benchmark_timer.h>
#include <vespa/vespalib/util/classname.h>
#include <vespa/vespalib/util/stringfmt.h>
+#include <vespa/vespalib/util/trinary.h>
#include <vespa/vespalib/gtest/gtest.h>
#include <optional>
@@ -115,7 +116,17 @@ Optimize universal_only() {
return Optimize::specific("universal_only", my_optimizer);
}
-void verify(const vespalib::string &expr, select_cell_type_t select_cell_type) {
+Trinary tri(bool value) {
+ return value ? Trinary::True : Trinary::False;
+}
+
+bool satisfies(bool actual, Trinary expect) {
+ return (expect == Trinary::Undefined) || (actual == (expect == Trinary::True));
+}
+
+void verify(const vespalib::string &expr, select_cell_type_t select_cell_type,
+ Trinary expect_forward, Trinary expect_distinct, Trinary expect_single)
+{
++verify_cnt;
auto fun = Function::parse(expr);
ASSERT_FALSE(fun->has_error());
@@ -134,10 +145,18 @@ void verify(const vespalib::string &expr, select_cell_type_t select_cell_type) {
const ValueType &expected_type = node_types.get_type(fun->root());
ASSERT_FALSE(expected_type.is_error());
Stash stash;
- size_t count = 0;
+ std::vector<const TensorFunction *> list;
const TensorFunction &plain_fun = make_tensor_function(prod_factory, fun->root(), node_types, stash);
- const TensorFunction &optimized = apply_tensor_function_optimizer(plain_fun, universal_only().optimizer, stash, &count);
- ASSERT_GT(count, 0);
+ const TensorFunction &optimized = apply_tensor_function_optimizer(plain_fun, universal_only().optimizer, stash,
+ [&list](const auto &node){
+ list.push_back(std::addressof(node));
+ });
+ ASSERT_EQ(list.size(), 1);
+ auto node = as<UniversalDotProduct>(*list[0]);
+ ASSERT_TRUE(node);
+ EXPECT_TRUE(satisfies(node->forward(), expect_forward));
+ EXPECT_TRUE(satisfies(node->distinct(), expect_distinct));
+ EXPECT_TRUE(satisfies(node->single(), expect_single));
InterpretedFunction ifun(prod_factory, optimized);
InterpretedFunction::Context ctx(ifun);
const Value &actual = ifun.eval(ctx, params);
@@ -152,7 +171,12 @@ void verify(const vespalib::string &expr, select_cell_type_t select_cell_type) {
auto expected = eval_ref(*fun, select_cell_type);
EXPECT_EQ(spec_from_value(actual), expected);
}
-void verify(const vespalib::string &expr) { verify(expr, always_double); }
+void verify(const vespalib::string &expr) {
+ verify(expr, always_double, Trinary::Undefined, Trinary::Undefined, Trinary::Undefined);
+}
+void verify(const vespalib::string &expr, select_cell_type_t select_cell_type, bool forward, bool distinct, bool single) {
+ verify(expr, select_cell_type, tri(forward), tri(distinct), tri(single));
+}
using cost_list_t = std::vector<std::pair<vespalib::string,double>>;
std::vector<std::pair<vespalib::string,cost_list_t>> benchmark_results;
@@ -192,8 +216,9 @@ void benchmark(const vespalib::string &expr, std::vector<Optimize> list) {
break;
case Optimize::With::SPECIFIC:
size_t count = 0;
- optimized = std::addressof(apply_tensor_function_optimizer(plain_fun, optimize.optimizer, stash, &count));
- ASSERT_GT(count, 0);
+ optimized = std::addressof(apply_tensor_function_optimizer(plain_fun, optimize.optimizer, stash,
+ [&count](const auto &)noexcept{ ++count; }));
+ ASSERT_EQ(count, 1);
break;
}
ASSERT_NE(optimized, nullptr);
@@ -255,36 +280,26 @@ TEST(UniversalDotProductTest, test_select_cell_types) {
}
TEST(UniversalDotProductTest, universal_dot_product_works_for_various_cases) {
- // forward, distinct, single
- verify("reduce(2.0*3.0, sum)");
+ // forward, distinct, single
+ verify("reduce(2.0*3.0, sum)", always_double, true, true, true);
for (CellType lct: CellTypeUtils::list_types()) {
for (CellType rct: CellTypeUtils::list_types()) {
auto sel2 = select(lct, rct);
- // !forward, !distinct, !single
- verify("reduce(a4_1x8*a2_1x8,sum,a,x)", sel2);
-
- // !forward, !distinct, single
- verify("reduce(a4_1x8*a2_1x8,sum,a)", sel2);
-
- // !forward, distinct, !single
- verify("reduce(a4_1x8*a2_1x8,sum,x)", sel2);
-
- // forward, !distinct, !single
- verify("reduce(a4_1x8*b2_1x8,sum,b,x)", sel2);
-
- // forward, !distinct, single
- verify("reduce(a4_1x8*b2_1x8,sum,b)", sel2);
-
- // forward, distinct, !single
- verify("reduce(a4_1x8*x8,sum,x)", sel2);
+ // forward, distinct, single
+ verify("reduce(a4_1x8*a2_1x8,sum,a,x)", sel2, false, false, false);
+ verify("reduce(a4_1x8*a2_1x8,sum,a)", sel2, false, false, true);
+ verify("reduce(a4_1x8*a2_1x8,sum,x)", sel2, false, true, false);
+ verify("reduce(a4_1x8*b2_1x8,sum,b,x)", sel2, true, false, false);
+ verify("reduce(a4_1x8*b2_1x8,sum,b)", sel2, true, false, true);
+ verify("reduce(a4_1x8*x8,sum,x)", sel2, true, true, false);
}
}
// !forward, distinct, single
-
+ //
// This case is not possible since 'distinct' implies '!single' as
// long as we reduce anything. The only expression allowed to
- // reduce nothing is the scalar case.
+ // reduce nothing is the scalar case, which satisfies 'forward'
}
TEST(UniversalDotProductTest, universal_dot_product_works_with_complex_dimension_nesting) {