diff options
author | Håvard Pettersen <havardpe@yahooinc.com> | 2023-09-12 12:15:29 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@yahooinc.com> | 2023-09-12 12:15:29 +0000 |
commit | 7a3a53ec4e11eff07cf425f368e5faed477e5fb5 (patch) | |
tree | 76223812496ca7160027ad3ce169cad6bccb12b7 /eval/src | |
parent | 5e335474fccb3dbfe0e631e72648d3ae8b1ff703 (diff) |
improve testing by verifying corner cases
Diffstat (limited to 'eval/src')
5 files changed, 78 insertions, 36 deletions
diff --git a/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp b/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp
index e1967f012cb..bf9aeead461 100644
--- a/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp
+++ b/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp
@@ -15,6 +15,7 @@
 #include <vespa/vespalib/util/benchmark_timer.h>
 #include <vespa/vespalib/util/classname.h>
 #include <vespa/vespalib/util/stringfmt.h>
+#include <vespa/vespalib/util/trinary.h>
 #include <vespa/vespalib/gtest/gtest.h>
 
 #include <optional>
@@ -115,7 +116,17 @@ Optimize universal_only() {
     return Optimize::specific("universal_only", my_optimizer);
 }
 
-void verify(const vespalib::string &expr, select_cell_type_t select_cell_type) {
+Trinary tri(bool value) {
+    return value ? Trinary::True : Trinary::False;
+}
+
+bool satisfies(bool actual, Trinary expect) {
+    return (expect == Trinary::Undefined) || (actual == (expect == Trinary::True));
+}
+
+void verify(const vespalib::string &expr, select_cell_type_t select_cell_type,
+            Trinary expect_forward, Trinary expect_distinct, Trinary expect_single)
+{
     ++verify_cnt;
     auto fun = Function::parse(expr);
     ASSERT_FALSE(fun->has_error());
@@ -134,10 +145,18 @@ void verify(const vespalib::string &expr, select_cell_type_t select_cell_type) {
     const ValueType &expected_type = node_types.get_type(fun->root());
     ASSERT_FALSE(expected_type.is_error());
     Stash stash;
-    size_t count = 0;
+    std::vector<const TensorFunction *> list;
     const TensorFunction &plain_fun = make_tensor_function(prod_factory, fun->root(), node_types, stash);
-    const TensorFunction &optimized = apply_tensor_function_optimizer(plain_fun, universal_only().optimizer, stash, &count);
-    ASSERT_GT(count, 0);
+    const TensorFunction &optimized = apply_tensor_function_optimizer(plain_fun, universal_only().optimizer, stash,
+                                                                      [&list](const auto &node){
+                                                                          list.push_back(std::addressof(node));
+                                                                      });
+    ASSERT_EQ(list.size(), 1);
+    auto node = as<UniversalDotProduct>(*list[0]);
+    ASSERT_TRUE(node);
+    EXPECT_TRUE(satisfies(node->forward(), expect_forward));
+    EXPECT_TRUE(satisfies(node->distinct(), expect_distinct));
+    EXPECT_TRUE(satisfies(node->single(), expect_single));
     InterpretedFunction ifun(prod_factory, optimized);
     InterpretedFunction::Context ctx(ifun);
     const Value &actual = ifun.eval(ctx, params);
@@ -152,7 +171,12 @@ void verify(const vespalib::string &expr, select_cell_type_t select_cell_type) {
     auto expected = eval_ref(*fun, select_cell_type);
     EXPECT_EQ(spec_from_value(actual), expected);
 }
-void verify(const vespalib::string &expr) { verify(expr, always_double); }
+void verify(const vespalib::string &expr) {
+    verify(expr, always_double, Trinary::Undefined, Trinary::Undefined, Trinary::Undefined);
+}
+void verify(const vespalib::string &expr, select_cell_type_t select_cell_type, bool forward, bool distinct, bool single) {
+    verify(expr, select_cell_type, tri(forward), tri(distinct), tri(single));
+}
 
 using cost_list_t = std::vector<std::pair<vespalib::string,double>>;
 std::vector<std::pair<vespalib::string,cost_list_t>> benchmark_results;
@@ -192,8 +216,9 @@ void benchmark(const vespalib::string &expr, std::vector<Optimize> list) {
             break;
         case Optimize::With::SPECIFIC:
             size_t count = 0;
-            optimized = std::addressof(apply_tensor_function_optimizer(plain_fun, optimize.optimizer, stash, &count));
-            ASSERT_GT(count, 0);
+            optimized = std::addressof(apply_tensor_function_optimizer(plain_fun, optimize.optimizer, stash,
+                                                                       [&count](const auto &)noexcept{ ++count; }));
+            ASSERT_EQ(count, 1);
             break;
         }
         ASSERT_NE(optimized, nullptr);
@@ -255,36 +280,26 @@ TEST(UniversalDotProductTest, test_select_cell_types) {
 }
 
 TEST(UniversalDotProductTest, universal_dot_product_works_for_various_cases) {
-    // forward, distinct, single
-    verify("reduce(2.0*3.0, sum)");
+    //                                           forward, distinct, single
+    verify("reduce(2.0*3.0, sum)", always_double,   true,     true,   true);
     for (CellType lct: CellTypeUtils::list_types()) {
         for (CellType rct: CellTypeUtils::list_types()) {
             auto sel2 = select(lct, rct);
-            // !forward, !distinct, !single
-            verify("reduce(a4_1x8*a2_1x8,sum,a,x)", sel2);
-
-            // !forward, !distinct, single
-            verify("reduce(a4_1x8*a2_1x8,sum,a)", sel2);
-
-            // !forward, distinct, !single
-            verify("reduce(a4_1x8*a2_1x8,sum,x)", sel2);
-
-            // forward, !distinct, !single
-            verify("reduce(a4_1x8*b2_1x8,sum,b,x)", sel2);
-
-            // forward, !distinct, single
-            verify("reduce(a4_1x8*b2_1x8,sum,b)", sel2);
-
-            // forward, distinct, !single
-            verify("reduce(a4_1x8*x8,sum,x)", sel2);
+            //                                            forward, distinct, single
+            verify("reduce(a4_1x8*a2_1x8,sum,a,x)", sel2,   false,    false,  false);
+            verify("reduce(a4_1x8*a2_1x8,sum,a)",   sel2,   false,    false,   true);
+            verify("reduce(a4_1x8*a2_1x8,sum,x)",   sel2,   false,     true,  false);
+            verify("reduce(a4_1x8*b2_1x8,sum,b,x)", sel2,    true,    false,  false);
+            verify("reduce(a4_1x8*b2_1x8,sum,b)",   sel2,    true,    false,   true);
+            verify("reduce(a4_1x8*x8,sum,x)",       sel2,    true,     true,  false);
         }
     }
 
     // !forward, distinct, single
-
+    //
     // This case is not possible since 'distinct' implies '!single' as
     // long as we reduce anything. The only expression allowed to
-    // reduce nothing is the scalar case.
+    // reduce nothing is the scalar case, which satisfies 'forward'
 }
 
 TEST(UniversalDotProductTest, universal_dot_product_works_with_complex_dimension_nesting) {
diff --git a/eval/src/vespa/eval/eval/optimize_tensor_function.cpp b/eval/src/vespa/eval/eval/optimize_tensor_function.cpp
index 4013021aaa4..7255d308c81 100644
--- a/eval/src/vespa/eval/eval/optimize_tensor_function.cpp
+++ b/eval/src/vespa/eval/eval/optimize_tensor_function.cpp
@@ -140,7 +140,7 @@ const TensorFunction &optimize_tensor_function(const ValueBuilderFactory &factor
     return optimize_tensor_function(factory, function, stash, OptimizeTensorFunctionOptions());
 }
 
-const TensorFunction &apply_tensor_function_optimizer(const TensorFunction &function, tensor_function_optimizer optimizer, Stash &stash, size_t *count) {
+const TensorFunction &apply_tensor_function_optimizer(const TensorFunction &function, tensor_function_optimizer optimizer, Stash &stash, tensor_function_listener listener) {
     Child root(function);
     run_optimize_pass(root, [&](const Child &child)
 {
@@ -148,9 +148,7 @@ const TensorFunction &apply_tensor_function_optimizer(const TensorFunction &func
             const TensorFunction &child_after = optimizer(child_before, stash);
             if (std::addressof(child_after) != std::addressof(child_before)) {
                 child.set(child_after);
-                if (count != nullptr) {
-                    ++(*count);
-                }
+                listener(child_after);
             }
         });
     return root.get();
 }
diff --git a/eval/src/vespa/eval/eval/optimize_tensor_function.h b/eval/src/vespa/eval/eval/optimize_tensor_function.h
index 4a5945860e7..fd8c9b33d8c 100644
--- a/eval/src/vespa/eval/eval/optimize_tensor_function.h
+++ b/eval/src/vespa/eval/eval/optimize_tensor_function.h
@@ -22,6 +22,8 @@ const TensorFunction &optimize_tensor_function(const ValueBuilderFactory &factor
 const TensorFunction &optimize_tensor_function(const ValueBuilderFactory &factory, const TensorFunction &function, Stash &stash);
 
 using tensor_function_optimizer = std::function<const TensorFunction &(const TensorFunction &expr, Stash &stash)>;
-const TensorFunction &apply_tensor_function_optimizer(const TensorFunction &function, tensor_function_optimizer optimizer, Stash &stash, size_t *count = nullptr);
+using tensor_function_listener = std::function<void(const TensorFunction &expr)>;
+const TensorFunction &apply_tensor_function_optimizer(const TensorFunction &function, tensor_function_optimizer optimizer, Stash &stash,
+                                                      tensor_function_listener = [](const TensorFunction &)noexcept{});
 
 } // namespace vespalib::eval
diff --git a/eval/src/vespa/eval/instruction/universal_dot_product.cpp b/eval/src/vespa/eval/instruction/universal_dot_product.cpp
index 414a54f09a8..e023609114a 100644
--- a/eval/src/vespa/eval/instruction/universal_dot_product.cpp
+++ b/eval/src/vespa/eval/instruction/universal_dot_product.cpp
@@ -40,6 +40,9 @@ struct UniversalDotProductParam {
             dense_plan.res_stride.pop_back();
         }
     }
+    bool forward() const { return sparse_plan.maybe_forward_lhs_index(); }
+    bool distinct() const { return sparse_plan.is_distinct() && dense_plan.is_distinct(); }
+    bool single() const { return vector_size == 1; }
 };
 
 template <typename OCT>
@@ -204,12 +207,33 @@ UniversalDotProduct::compile_self(const ValueBuilderFactory &, Stash &stash) con
     auto op = typify_invoke<6,MyTypify,SelectUniversalDotProduct>(lhs().result_type().cell_meta(),
                                                                   rhs().result_type().cell_meta(),
                                                                   result_type().cell_meta().is_scalar,
-                                                                  param.sparse_plan.maybe_forward_lhs_index(),
-                                                                  param.sparse_plan.is_distinct() && param.dense_plan.is_distinct(),
-                                                                  param.vector_size == 1);
+                                                                  param.forward(),
+                                                                  param.distinct(),
+                                                                  param.single());
     return InterpretedFunction::Instruction(op, wrap_param<UniversalDotProductParam>(param));
 }
 
+bool
+UniversalDotProduct::forward() const
+{
+    UniversalDotProductParam param(result_type(), lhs().result_type(), rhs().result_type());
+    return param.forward();
+}
+
+bool
+UniversalDotProduct::distinct() const
+{
+    UniversalDotProductParam param(result_type(), lhs().result_type(), rhs().result_type());
+    return param.distinct();
+}
+
+bool
+UniversalDotProduct::single() const
+{
+    UniversalDotProductParam param(result_type(), lhs().result_type(), rhs().result_type());
+    return param.single();
+}
+
 const TensorFunction &
 UniversalDotProduct::optimize(const TensorFunction &expr, Stash &stash, bool force)
 {
diff --git a/eval/src/vespa/eval/instruction/universal_dot_product.h b/eval/src/vespa/eval/instruction/universal_dot_product.h
index 40fd109cc73..2572ab47c65 100644
--- a/eval/src/vespa/eval/instruction/universal_dot_product.h
+++ b/eval/src/vespa/eval/instruction/universal_dot_product.h
@@ -19,6 +19,9 @@ public:
     UniversalDotProduct(const ValueType &res_type, const TensorFunction &lhs, const TensorFunction &rhs);
     InterpretedFunction::Instruction compile_self(const ValueBuilderFactory &factory, Stash &stash) const override;
     bool result_is_mutable() const override { return true; }
+    bool forward() const;
+    bool distinct() const;
+    bool single() const;
     static const TensorFunction &optimize(const TensorFunction &expr, Stash &stash, bool force);
 };