aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2021-03-15 12:21:39 +0000
committerHåvard Pettersen <havardpe@oath.com>2021-03-15 12:21:39 +0000
commita6957ce9c2716396786117ee322de79aa0e8d7cc (patch)
treed46dd767599114793b526ffaa728f15c5cd02f3b /eval/src
parent77926c9c0ca9128d9d75fd49e614dc62465ca45e (diff)
test with all cell types
Diffstat (limited to 'eval/src')
-rw-r--r--eval/src/tests/instruction/dense_xw_product_function/dense_xw_product_function_test.cpp122
1 files changed, 54 insertions, 68 deletions
diff --git a/eval/src/tests/instruction/dense_xw_product_function/dense_xw_product_function_test.cpp b/eval/src/tests/instruction/dense_xw_product_function/dense_xw_product_function_test.cpp
index f4fce7cb5f5..770ba337a2d 100644
--- a/eval/src/tests/instruction/dense_xw_product_function/dense_xw_product_function_test.cpp
+++ b/eval/src/tests/instruction/dense_xw_product_function/dense_xw_product_function_test.cpp
@@ -18,91 +18,74 @@ using namespace vespalib::eval::tensor_function;
const ValueBuilderFactory &prod_factory = FastValueBuilderFactory::get();
-struct First {
- bool value;
- explicit First(bool value_in) : value(value_in) {}
- operator bool() const { return value; }
+GenSpec::seq_t lhs_seq = [] (size_t i) noexcept { return (3.0 + i) * 7.0; };
+GenSpec::seq_t rhs_seq = [] (size_t i) noexcept { return (5.0 + i) * 43.0; };
+
+struct FunInfo {
+ using LookFor = DenseXWProductFunction;
+ size_t vec_size;
+ size_t res_size;
+ bool happy;
+ bool check(const LookFor &fun) const {
+ return ((fun.result_is_mutable()) &&
+ (fun.vector_size() == vec_size) &&
+ (fun.result_size() == res_size) &&
+ (fun.common_inner() == happy));
+ }
};
-GenSpec::seq_t my_vec_seq = [] (size_t i) noexcept { return (3.0 + i) * 7.0; };
-GenSpec::seq_t my_mat_seq = [] (size_t i) noexcept { return (5.0 + i) * 43.0; };
-
-void add_vector(EvalFixture::ParamRepo &repo, const char *d1, size_t s1) {
- auto name = make_string("%s%zu", d1, s1);
- auto layout = GenSpec().idx(d1, s1).seq(my_vec_seq);
- repo.add(name, layout);
- repo.add(name + "f", layout.cells_float());
-}
-
-void add_matrix(EvalFixture::ParamRepo &repo, const char *d1, size_t s1, const char *d2, size_t s2) {
- auto name = make_string("%s%zu%s%zu", d1, s1, d2, s2);
- auto layout = GenSpec().idx(d1, s1).idx(d2, s2).seq(my_mat_seq);
- repo.add(name, layout);
- repo.add(name + "f", layout.cells_float());
+void verify(const vespalib::string &expr, const std::vector<FunInfo> &fun_info, const std::vector<CellType> &with_cell_types) {
+ auto fun = Function::parse(expr);
+ ASSERT_EQUAL(fun->num_params(), 2u);
+ vespalib::string lhs_name = fun->param_name(0);
+ vespalib::string rhs_name = fun->param_name(1);
+ const auto lhs_spec = GenSpec::from_desc(lhs_name);
+ const auto rhs_spec = GenSpec::from_desc(rhs_name);
+ for (CellType lhs_ct: with_cell_types) {
+ for (CellType rhs_ct: with_cell_types) {
+ EvalFixture::ParamRepo param_repo;
+ param_repo.add(lhs_name, lhs_spec.cpy().cells(lhs_ct).seq(lhs_seq));
+ param_repo.add(rhs_name, rhs_spec.cpy().cells(rhs_ct).seq(rhs_seq));
+ EvalFixture slow_fixture(prod_factory, expr, param_repo, false);
+ EvalFixture fixture(prod_factory, expr, param_repo, true);
+ EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo));
+ EXPECT_EQUAL(fixture.result(), slow_fixture.result());
+ auto info = fixture.find_all<FunInfo::LookFor>();
+ ASSERT_EQUAL(info.size(), fun_info.size());
+ for (size_t i = 0; i < fun_info.size(); ++i) {
+ EXPECT_TRUE(fun_info[i].check(*info[i]));
+ }
+ }
+ }
}
-EvalFixture::ParamRepo make_params() {
- EvalFixture::ParamRepo repo;
- add_vector(repo, "y", 1);
- add_vector(repo, "y", 3);
- add_vector(repo, "y", 5);
- add_vector(repo, "y", 16);
- add_matrix(repo, "x", 1, "y", 1);
- add_matrix(repo, "y", 1, "z", 1);
- add_matrix(repo, "x", 2, "y", 3);
- add_matrix(repo, "y", 3, "z", 2);
- add_matrix(repo, "x", 2, "z", 3);
- add_matrix(repo, "x", 8, "y", 5);
- add_matrix(repo, "y", 5, "z", 8);
- add_matrix(repo, "x", 5, "y", 16);
- add_matrix(repo, "y", 16, "z", 5);
- return repo;
+void verify_not_optimized(const vespalib::string &expr) {
+ return verify(expr, {}, {CellType::FLOAT});
}
-EvalFixture::ParamRepo param_repo = make_params();
void verify_optimized(const vespalib::string &expr, size_t vec_size, size_t res_size, bool happy) {
- EvalFixture slow_fixture(prod_factory, expr, param_repo, false);
- EvalFixture fixture(prod_factory, expr, param_repo, true);
- EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo));
- EXPECT_EQUAL(fixture.result(), slow_fixture.result());
- auto info = fixture.find_all<DenseXWProductFunction>();
- ASSERT_EQUAL(info.size(), 1u);
- EXPECT_TRUE(info[0]->result_is_mutable());
- EXPECT_EQUAL(info[0]->vector_size(), vec_size);
- EXPECT_EQUAL(info[0]->result_size(), res_size);
- EXPECT_EQUAL(info[0]->common_inner(), happy);
+ return verify(expr, {{vec_size, res_size, happy}}, CellTypeUtils::list_types());
}
-vespalib::string make_expr(const vespalib::string &a, const vespalib::string &b, const vespalib::string &common,
- bool float_a, bool float_b)
-{
- return make_string("reduce(%s%s*%s%s,sum,%s)", a.c_str(), float_a ? "f" : "", b.c_str(), float_b ? "f" : "", common.c_str());
+vespalib::string make_expr(const vespalib::string &a, const vespalib::string &b, const vespalib::string &common) {
+ return make_string("reduce(%s*%s,sum,%s)", a.c_str(), b.c_str(), common.c_str());
}
void verify_optimized_multi(const vespalib::string &a, const vespalib::string &b, const vespalib::string &common,
- size_t vec_size, size_t res_size, bool happy, First first = First(true))
+ size_t vec_size, size_t res_size, bool happy)
{
- for (bool float_a: {false, true}) {
- for (bool float_b: {false, true}) {
- auto expr = make_expr(a, b, common, float_a, float_b);
- TEST_STATE(expr.c_str());
- TEST_DO(verify_optimized(expr, vec_size, res_size, happy));
- }
+ {
+ auto expr = make_expr(a, b, common);
+ TEST_STATE(expr.c_str());
+ TEST_DO(verify_optimized(expr, vec_size, res_size, happy));
}
- if (first) {
- TEST_DO(verify_optimized_multi(b, a, common, vec_size, res_size, happy, First(false)));
+ {
+ auto expr = make_expr(b, a, common);
+ TEST_STATE(expr.c_str());
+ TEST_DO(verify_optimized(expr, vec_size, res_size, happy));
}
}
-void verify_not_optimized(const vespalib::string &expr) {
- EvalFixture slow_fixture(prod_factory, expr, param_repo, false);
- EvalFixture fixture(prod_factory, expr, param_repo, true);
- EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo));
- EXPECT_EQUAL(fixture.result(), slow_fixture.result());
- auto info = fixture.find_all<DenseXWProductFunction>();
- EXPECT_TRUE(info.empty());
-}
-
TEST("require that xw product gives same results as reference join/reduce") {
// 1 -> 1 happy/unhappy
TEST_DO(verify_optimized_multi("y1", "x1y1", "y", 1, 1, true));
@@ -136,6 +119,9 @@ TEST("require that expressions similar to xw product are not optimized") {
}
TEST("require that xw product can be debug dumped") {
+ EvalFixture::ParamRepo param_repo;
+ param_repo.add("y5", GenSpec::from_desc("y5"));
+ param_repo.add("x8y5", GenSpec::from_desc("x8y5"));
EvalFixture fixture(prod_factory, "reduce(y5*x8y5,sum,y)", param_repo, true);
auto info = fixture.find_all<DenseXWProductFunction>();
ASSERT_EQUAL(info.size(), 1u);