From 7a3a53ec4e11eff07cf425f368e5faed477e5fb5 Mon Sep 17 00:00:00 2001 From: HÃ¥vard Pettersen Date: Tue, 12 Sep 2023 12:15:29 +0000 Subject: improve testing by verifying corner cases --- .../universal_dot_product_test.cpp | 71 +++++++++++++--------- 1 file changed, 43 insertions(+), 28 deletions(-) (limited to 'eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp') diff --git a/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp b/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp index e1967f012cb..bf9aeead461 100644 --- a/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp +++ b/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include @@ -115,7 +116,17 @@ Optimize universal_only() { return Optimize::specific("universal_only", my_optimizer); } -void verify(const vespalib::string &expr, select_cell_type_t select_cell_type) { +Trinary tri(bool value) { + return value ? Trinary::True : Trinary::False; +} + +bool satisfies(bool actual, Trinary expect) { + return (expect == Trinary::Undefined) || (actual == (expect == Trinary::True)); +} + +void verify(const vespalib::string &expr, select_cell_type_t select_cell_type, + Trinary expect_forward, Trinary expect_distinct, Trinary expect_single) +{ ++verify_cnt; auto fun = Function::parse(expr); ASSERT_FALSE(fun->has_error()); @@ -134,10 +145,18 @@ void verify(const vespalib::string &expr, select_cell_type_t select_cell_type) { const ValueType &expected_type = node_types.get_type(fun->root()); ASSERT_FALSE(expected_type.is_error()); Stash stash; - size_t count = 0; + std::vector list; const TensorFunction &plain_fun = make_tensor_function(prod_factory, fun->root(), node_types, stash); - const TensorFunction &optimized = apply_tensor_function_optimizer(plain_fun, universal_only().optimizer, stash, &count); - ASSERT_GT(count, 0); + const TensorFunction &optimized = apply_tensor_function_optimizer(plain_fun, universal_only().optimizer, stash, + [&list](const auto &node){ + list.push_back(std::addressof(node)); + }); + ASSERT_EQ(list.size(), 1); + auto node = as(*list[0]); + ASSERT_TRUE(node); + EXPECT_TRUE(satisfies(node->forward(), expect_forward)); + EXPECT_TRUE(satisfies(node->distinct(), expect_distinct)); + EXPECT_TRUE(satisfies(node->single(), expect_single)); InterpretedFunction ifun(prod_factory, optimized); InterpretedFunction::Context ctx(ifun); const Value &actual = ifun.eval(ctx, params); @@ -152,7 +171,12 @@ void verify(const vespalib::string &expr, select_cell_type_t select_cell_type) { auto expected = eval_ref(*fun, select_cell_type); EXPECT_EQ(spec_from_value(actual), expected); } -void verify(const vespalib::string &expr) { verify(expr, always_double); } +void verify(const vespalib::string &expr) { + verify(expr, always_double, Trinary::Undefined, Trinary::Undefined, Trinary::Undefined); +} +void verify(const vespalib::string &expr, select_cell_type_t select_cell_type, bool forward, bool distinct, bool single) { + verify(expr, select_cell_type, tri(forward), tri(distinct), tri(single)); +} using cost_list_t = std::vector>; std::vector> benchmark_results; @@ -192,8 +216,9 @@ void benchmark(const vespalib::string &expr, std::vector list) { break; case Optimize::With::SPECIFIC: size_t count = 0; - optimized = std::addressof(apply_tensor_function_optimizer(plain_fun, optimize.optimizer, stash, &count)); - ASSERT_GT(count, 0); + optimized = std::addressof(apply_tensor_function_optimizer(plain_fun, optimize.optimizer, stash, + [&count](const auto &)noexcept{ ++count; })); + ASSERT_EQ(count, 1); break; } ASSERT_NE(optimized, nullptr); @@ -255,36 +280,26 @@ TEST(UniversalDotProductTest, test_select_cell_types) { } TEST(UniversalDotProductTest, universal_dot_product_works_for_various_cases) { - // forward, distinct, single - verify("reduce(2.0*3.0, sum)"); + // forward, distinct, single + verify("reduce(2.0*3.0, sum)", always_double, true, true, true); for (CellType lct: CellTypeUtils::list_types()) { for (CellType rct: CellTypeUtils::list_types()) { auto sel2 = select(lct, rct); - // !forward, !distinct, !single - verify("reduce(a4_1x8*a2_1x8,sum,a,x)", sel2); - - // !forward, !distinct, single - verify("reduce(a4_1x8*a2_1x8,sum,a)", sel2); - - // !forward, distinct, !single - verify("reduce(a4_1x8*a2_1x8,sum,x)", sel2); - - // forward, !distinct, !single - verify("reduce(a4_1x8*b2_1x8,sum,b,x)", sel2); - - // forward, !distinct, single - verify("reduce(a4_1x8*b2_1x8,sum,b)", sel2); - - // forward, distinct, !single - verify("reduce(a4_1x8*x8,sum,x)", sel2); + // forward, distinct, single + verify("reduce(a4_1x8*a2_1x8,sum,a,x)", sel2, false, false, false); + verify("reduce(a4_1x8*a2_1x8,sum,a)", sel2, false, false, true); + verify("reduce(a4_1x8*a2_1x8,sum,x)", sel2, false, true, false); + verify("reduce(a4_1x8*b2_1x8,sum,b,x)", sel2, true, false, false); + verify("reduce(a4_1x8*b2_1x8,sum,b)", sel2, true, false, true); + verify("reduce(a4_1x8*x8,sum,x)", sel2, true, true, false); } } // !forward, distinct, single - + // // This case is not possible since 'distinct' implies '!single' as // long as we reduce anything. The only expression allowed to - // reduce nothing is the scalar case. + // reduce nothing is the scalar case, which satisfies 'forward' } TEST(UniversalDotProductTest, universal_dot_product_works_with_complex_dimension_nesting) { -- cgit v1.2.3