diff options
author | Arne Juul <arnej@verizonmedia.com> | 2021-03-17 12:53:29 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2021-03-17 14:07:56 +0000 |
commit | 12e18f4ae9bb756066f74d7076253d5ad3803987 (patch) | |
tree | 03a10e8ef8a1fa53a92139be5205ab2b7b72126c /eval | |
parent | a02f7e628be4196a83f7ab0e379c74a223eefc93 (diff) |
rewrite DenseMatMulFunction test using EvalFixture::verify<FunInfo>
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/tests/instruction/dense_matmul_function/dense_matmul_function_test.cpp | 111 |
1 files changed, 45 insertions, 66 deletions
diff --git a/eval/src/tests/instruction/dense_matmul_function/dense_matmul_function_test.cpp b/eval/src/tests/instruction/dense_matmul_function/dense_matmul_function_test.cpp index bcc5e6c7b5d..438424fdf4e 100644 --- a/eval/src/tests/instruction/dense_matmul_function/dense_matmul_function_test.cpp +++ b/eval/src/tests/instruction/dense_matmul_function/dense_matmul_function_test.cpp @@ -8,7 +8,6 @@ #include <vespa/eval/instruction/dense_matmul_function.h> #include <vespa/vespalib/testkit/test_kit.h> #include <vespa/vespalib/util/stash.h> -#include <vespa/vespalib/util/stringfmt.h> using namespace vespalib; using namespace vespalib::eval; @@ -17,50 +16,45 @@ using namespace vespalib::eval::tensor_function; const ValueBuilderFactory &prod_factory = FastValueBuilderFactory::get(); -EvalFixture::ParamRepo make_params() { - return EvalFixture::ParamRepo() - .add_variants("a2d3", GenSpec().idx("a", 2).idx("d", 3)) // inner/inner - .add_variants("a2b5", GenSpec().idx("a", 2).idx("b", 5)) // inner/outer - .add_variants("b5c2", GenSpec().idx("b", 5).idx("c", 2)) // outer/outer - .add_variants("a2c3", GenSpec().idx("a", 2).idx("c", 3)) // not matching - //------------------------------------------ - .add_variants("b5d3", GenSpec().idx("b", 5).idx("d", 3)); // fixed param -} -EvalFixture::ParamRepo param_repo = make_params(); +struct FunInfo { + using LookFor = DenseMatMulFunction; + size_t lhs_size; + size_t common_size; + size_t rhs_size; + bool lhs_inner; + bool rhs_inner; + void verify(const LookFor &fun) const { + EXPECT_TRUE(fun.result_is_mutable()); + EXPECT_EQUAL(fun.lhs_size(), lhs_size); + EXPECT_EQUAL(fun.common_size(), common_size); + EXPECT_EQUAL(fun.rhs_size(), rhs_size); + EXPECT_EQUAL(fun.lhs_common_inner(), lhs_inner); + EXPECT_EQUAL(fun.rhs_common_inner(), rhs_inner); + } +}; -void verify_optimized(const vespalib::string &expr, - size_t lhs_size, size_t common_size, size_t rhs_size, - bool lhs_inner, bool rhs_inner) -{ - EvalFixture slow_fixture(prod_factory, expr, param_repo, false); - EvalFixture fixture(prod_factory, expr, param_repo, true); - EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo)); - EXPECT_EQUAL(fixture.result(), slow_fixture.result()); - auto info = fixture.find_all<DenseMatMulFunction>(); - ASSERT_EQUAL(info.size(), 1u); - EXPECT_TRUE(info[0]->result_is_mutable()); - EXPECT_EQUAL(info[0]->lhs_size(), lhs_size); - EXPECT_EQUAL(info[0]->common_size(), common_size); - EXPECT_EQUAL(info[0]->rhs_size(), rhs_size); - EXPECT_EQUAL(info[0]->lhs_common_inner(), lhs_inner); - EXPECT_EQUAL(info[0]->rhs_common_inner(), rhs_inner); +void verify_optimized(const vespalib::string &expr, FunInfo details) { + TEST_STATE(expr.c_str()); + CellTypeSpace all_types(CellTypeUtils::list_types(), 2); + EvalFixture::verify<FunInfo>(expr, {details}, all_types); } void verify_not_optimized(const vespalib::string &expr) { - EvalFixture slow_fixture(prod_factory, expr, param_repo, false); - EvalFixture fixture(prod_factory, expr, param_repo, true); - EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo)); - EXPECT_EQUAL(fixture.result(), slow_fixture.result()); - auto info = fixture.find_all<DenseMatMulFunction>(); - EXPECT_TRUE(info.empty()); + TEST_STATE(expr.c_str()); + CellTypeSpace just_double({CellType::DOUBLE}, 2); + EvalFixture::verify<FunInfo>(expr, {}, just_double); } TEST("require that matmul can be optimized") { - TEST_DO(verify_optimized("reduce(a2d3*b5d3,sum,d)", 2, 3, 5, true, true)); + FunInfo details = { .lhs_size = 2, .common_size = 3, .rhs_size = 5, + .lhs_inner = true, .rhs_inner = true }; + TEST_DO(verify_optimized("reduce(a2d3*b5d3,sum,d)", details)); } TEST("require that matmul with lambda can be optimized") { - TEST_DO(verify_optimized("reduce(join(a2d3,b5d3,f(x,y)(x*y)),sum,d)", 2, 3, 5, true, true)); + FunInfo details = { .lhs_size = 2, .common_size = 3, .rhs_size = 5, + .lhs_inner = true, .rhs_inner = true }; + TEST_DO(verify_optimized("reduce(join(a2d3,b5d3,f(x,y)(x*y)),sum,d)", details)); } TEST("require that expressions similar to matmul are not optimized") { @@ -77,49 +71,34 @@ TEST("require that expressions similar to matmul are not optimized") { TEST_DO(verify_not_optimized("reduce(a2c3*b5d3,sum,c)")); } -TEST("require that xw product can be debug dumped") { - EvalFixture fixture(prod_factory, "reduce(a2d3*b5d3,sum,d)", param_repo, true); +TEST("require that MatMul can be debug dumped") { + EvalFixture fixture(prod_factory, "reduce(x*y,sum,d)", EvalFixture::ParamRepo() + .add("x", GenSpec::from_desc("a2d3")) + .add("y", GenSpec::from_desc("b5d3")), true); auto info = fixture.find_all<DenseMatMulFunction>(); ASSERT_EQUAL(info.size(), 1u); fprintf(stderr, "%s\n", info[0]->as_string().c_str()); } -vespalib::string make_expr(const vespalib::string &a, const vespalib::string &b, const vespalib::string &common, - bool float_a, bool float_b) -{ - return make_string("reduce(%s%s*%s%s,sum,%s)", a.c_str(), float_a ? "_f" : "", b.c_str(), float_b ? "_f" : "", common.c_str()); -} - -void verify_optimized_multi(const vespalib::string &a, const vespalib::string &b, const vespalib::string &common, - size_t lhs_size, size_t common_size, size_t rhs_size, - bool lhs_inner, bool rhs_inner) -{ - for (bool float_a: {false, true}) { - for (bool float_b: {false, true}) { - { - auto expr = make_expr(a, b, common, float_a, float_b); - TEST_STATE(expr.c_str()); - TEST_DO(verify_optimized(expr, lhs_size, common_size, rhs_size, lhs_inner, rhs_inner)); - } - { - auto expr = make_expr(b, a, common, float_b, float_a); - TEST_STATE(expr.c_str()); - TEST_DO(verify_optimized(expr, lhs_size, common_size, rhs_size, lhs_inner, rhs_inner)); - } - } - } -} - TEST("require that matmul inner/inner works correctly") { - TEST_DO(verify_optimized_multi("a2d3", "b5d3", "d", 2, 3, 5, true, true)); + FunInfo details = { .lhs_size = 2, .common_size = 3, .rhs_size = 5, + .lhs_inner = true, .rhs_inner = true }; + TEST_DO(verify_optimized("reduce(a2d3*b5d3,sum,d)", details)); + TEST_DO(verify_optimized("reduce(b5d3*a2d3,sum,d)", details)); } TEST("require that matmul inner/outer works correctly") { - TEST_DO(verify_optimized_multi("a2b5", "b5d3", "b", 2, 5, 3, true, false)); + FunInfo details = { .lhs_size = 2, .common_size = 5, .rhs_size = 3, + .lhs_inner = true, .rhs_inner = false }; + TEST_DO(verify_optimized("reduce(a2b5*b5d3,sum,b)", details)); + TEST_DO(verify_optimized("reduce(b5d3*a2b5,sum,b)", details)); } TEST("require that matmul outer/outer works correctly") { - TEST_DO(verify_optimized_multi("b5c2", "b5d3", "b", 2, 5, 3, false, false)); + FunInfo details = { .lhs_size = 2, .common_size = 5, .rhs_size = 3, + .lhs_inner = false, .rhs_inner = false }; + TEST_DO(verify_optimized("reduce(b5c2*b5d3,sum,b)", details)); + TEST_DO(verify_optimized("reduce(b5d3*b5c2,sum,b)", details)); } TEST_MAIN() { TEST_RUN_ALL(); } |