diff options
author | Håvard Pettersen <havardpe@oath.com> | 2021-03-15 12:21:39 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2021-03-15 12:21:39 +0000 |
commit | a6957ce9c2716396786117ee322de79aa0e8d7cc (patch) | |
tree | d46dd767599114793b526ffaa728f15c5cd02f3b /eval/src | |
parent | 77926c9c0ca9128d9d75fd49e614dc62465ca45e (diff) |
test with all cell types
Diffstat (limited to 'eval/src')
-rw-r--r-- | eval/src/tests/instruction/dense_xw_product_function/dense_xw_product_function_test.cpp | 122 |
1 files changed, 54 insertions, 68 deletions
diff --git a/eval/src/tests/instruction/dense_xw_product_function/dense_xw_product_function_test.cpp b/eval/src/tests/instruction/dense_xw_product_function/dense_xw_product_function_test.cpp index f4fce7cb5f5..770ba337a2d 100644 --- a/eval/src/tests/instruction/dense_xw_product_function/dense_xw_product_function_test.cpp +++ b/eval/src/tests/instruction/dense_xw_product_function/dense_xw_product_function_test.cpp @@ -18,91 +18,74 @@ using namespace vespalib::eval::tensor_function; const ValueBuilderFactory &prod_factory = FastValueBuilderFactory::get(); -struct First { - bool value; - explicit First(bool value_in) : value(value_in) {} - operator bool() const { return value; } +GenSpec::seq_t lhs_seq = [] (size_t i) noexcept { return (3.0 + i) * 7.0; }; +GenSpec::seq_t rhs_seq = [] (size_t i) noexcept { return (5.0 + i) * 43.0; }; + +struct FunInfo { + using LookFor = DenseXWProductFunction; + size_t vec_size; + size_t res_size; + bool happy; + bool check(const LookFor &fun) const { + return ((fun.result_is_mutable()) && + (fun.vector_size() == vec_size) && + (fun.result_size() == res_size) && + (fun.common_inner() == happy)); + } }; -GenSpec::seq_t my_vec_seq = [] (size_t i) noexcept { return (3.0 + i) * 7.0; }; -GenSpec::seq_t my_mat_seq = [] (size_t i) noexcept { return (5.0 + i) * 43.0; }; - -void add_vector(EvalFixture::ParamRepo &repo, const char *d1, size_t s1) { - auto name = make_string("%s%zu", d1, s1); - auto layout = GenSpec().idx(d1, s1).seq(my_vec_seq); - repo.add(name, layout); - repo.add(name + "f", layout.cells_float()); -} - -void add_matrix(EvalFixture::ParamRepo &repo, const char *d1, size_t s1, const char *d2, size_t s2) { - auto name = make_string("%s%zu%s%zu", d1, s1, d2, s2); - auto layout = GenSpec().idx(d1, s1).idx(d2, s2).seq(my_mat_seq); - repo.add(name, layout); - repo.add(name + "f", layout.cells_float()); +void verify(const vespalib::string &expr, const std::vector<FunInfo> &fun_info, const std::vector<CellType> &with_cell_types) { + auto fun = Function::parse(expr); + ASSERT_EQUAL(fun->num_params(), 2u); + vespalib::string lhs_name = fun->param_name(0); + vespalib::string rhs_name = fun->param_name(1); + const auto lhs_spec = GenSpec::from_desc(lhs_name); + const auto rhs_spec = GenSpec::from_desc(rhs_name); + for (CellType lhs_ct: with_cell_types) { + for (CellType rhs_ct: with_cell_types) { + EvalFixture::ParamRepo param_repo; + param_repo.add(lhs_name, lhs_spec.cpy().cells(lhs_ct).seq(lhs_seq)); + param_repo.add(rhs_name, rhs_spec.cpy().cells(rhs_ct).seq(rhs_seq)); + EvalFixture slow_fixture(prod_factory, expr, param_repo, false); + EvalFixture fixture(prod_factory, expr, param_repo, true); + EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo)); + EXPECT_EQUAL(fixture.result(), slow_fixture.result()); + auto info = fixture.find_all<FunInfo::LookFor>(); + ASSERT_EQUAL(info.size(), fun_info.size()); + for (size_t i = 0; i < fun_info.size(); ++i) { + EXPECT_TRUE(fun_info[i].check(*info[i])); + } + } + } } -EvalFixture::ParamRepo make_params() { - EvalFixture::ParamRepo repo; - add_vector(repo, "y", 1); - add_vector(repo, "y", 3); - add_vector(repo, "y", 5); - add_vector(repo, "y", 16); - add_matrix(repo, "x", 1, "y", 1); - add_matrix(repo, "y", 1, "z", 1); - add_matrix(repo, "x", 2, "y", 3); - add_matrix(repo, "y", 3, "z", 2); - add_matrix(repo, "x", 2, "z", 3); - add_matrix(repo, "x", 8, "y", 5); - add_matrix(repo, "y", 5, "z", 8); - add_matrix(repo, "x", 5, "y", 16); - add_matrix(repo, "y", 16, "z", 5); - return repo; +void verify_not_optimized(const vespalib::string &expr) { + return verify(expr, {}, {CellType::FLOAT}); } -EvalFixture::ParamRepo param_repo = make_params(); void verify_optimized(const vespalib::string &expr, size_t vec_size, size_t res_size, bool happy) { - EvalFixture slow_fixture(prod_factory, expr, param_repo, false); - EvalFixture fixture(prod_factory, expr, param_repo, true); - EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo)); - EXPECT_EQUAL(fixture.result(), slow_fixture.result()); - auto info = fixture.find_all<DenseXWProductFunction>(); - ASSERT_EQUAL(info.size(), 1u); - EXPECT_TRUE(info[0]->result_is_mutable()); - EXPECT_EQUAL(info[0]->vector_size(), vec_size); - EXPECT_EQUAL(info[0]->result_size(), res_size); - EXPECT_EQUAL(info[0]->common_inner(), happy); + return verify(expr, {{vec_size, res_size, happy}}, CellTypeUtils::list_types()); } -vespalib::string make_expr(const vespalib::string &a, const vespalib::string &b, const vespalib::string &common, - bool float_a, bool float_b) -{ - return make_string("reduce(%s%s*%s%s,sum,%s)", a.c_str(), float_a ? "f" : "", b.c_str(), float_b ? "f" : "", common.c_str()); +vespalib::string make_expr(const vespalib::string &a, const vespalib::string &b, const vespalib::string &common) { + return make_string("reduce(%s*%s,sum,%s)", a.c_str(), b.c_str(), common.c_str()); } void verify_optimized_multi(const vespalib::string &a, const vespalib::string &b, const vespalib::string &common, - size_t vec_size, size_t res_size, bool happy, First first = First(true)) + size_t vec_size, size_t res_size, bool happy) { - for (bool float_a: {false, true}) { - for (bool float_b: {false, true}) { - auto expr = make_expr(a, b, common, float_a, float_b); - TEST_STATE(expr.c_str()); - TEST_DO(verify_optimized(expr, vec_size, res_size, happy)); - } + { + auto expr = make_expr(a, b, common); + TEST_STATE(expr.c_str()); + TEST_DO(verify_optimized(expr, vec_size, res_size, happy)); } - if (first) { - TEST_DO(verify_optimized_multi(b, a, common, vec_size, res_size, happy, First(false))); + { + auto expr = make_expr(b, a, common); + TEST_STATE(expr.c_str()); + TEST_DO(verify_optimized(expr, vec_size, res_size, happy)); } } -void verify_not_optimized(const vespalib::string &expr) { - EvalFixture slow_fixture(prod_factory, expr, param_repo, false); - EvalFixture fixture(prod_factory, expr, param_repo, true); - EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo)); - EXPECT_EQUAL(fixture.result(), slow_fixture.result()); - auto info = fixture.find_all<DenseXWProductFunction>(); - EXPECT_TRUE(info.empty()); -} - TEST("require that xw product gives same results as reference join/reduce") { // 1 -> 1 happy/unhappy TEST_DO(verify_optimized_multi("y1", "x1y1", "y", 1, 1, true)); @@ -136,6 +119,9 @@ TEST("require that expressions similar to xw product are not optimized") { } TEST("require that xw product can be debug dumped") { + EvalFixture::ParamRepo param_repo; + param_repo.add("y5", GenSpec::from_desc("y5")); + param_repo.add("x8y5", GenSpec::from_desc("x8y5")); EvalFixture fixture(prod_factory, "reduce(y5*x8y5,sum,y)", param_repo, true); auto info = fixture.find_all<DenseXWProductFunction>(); ASSERT_EQUAL(info.size(), 1u); |