summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorArne H Juul <arnej27959@users.noreply.github.com>2021-03-17 15:04:45 +0100
committerGitHub <noreply@github.com>2021-03-17 15:04:45 +0100
commit451495718485047eebd66fd37bf53b2c27a19f10 (patch)
treebe17c5c8870dde23ad622c5999eb8d9abb8aaae4 /eval
parent48ff630a4bc00a32dd2897b441d449f646c1e927 (diff)
parentedfce5980b08620274b1b7a2bc86789e8b74a373 (diff)
Merge pull request #17007 from vespa-engine/arnej/update_dense_multi_matmul
rewrite DenseMultiMatMulFunction test
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/instruction/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp136
1 files changed, 52 insertions, 84 deletions
diff --git a/eval/src/tests/instruction/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp b/eval/src/tests/instruction/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp
index 9138668b8c4..9fa25466f4a 100644
--- a/eval/src/tests/instruction/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp
+++ b/eval/src/tests/instruction/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp
@@ -17,70 +17,57 @@ using namespace vespalib::eval::tensor_function;
const ValueBuilderFactory &prod_factory = FastValueBuilderFactory::get();
-GenSpec G(std::vector<std::pair<const char *, size_t>> dims) {
- GenSpec result;
- for (const auto & dim : dims) {
- result.idx(dim.first, dim.second);
+struct FunInfo {
+ using LookFor = DenseMultiMatMulFunction;
+ size_t lhs_size;
+ size_t common_size;
+ size_t rhs_size;
+ size_t matmul_cnt;
+ bool lhs_inner;
+ bool rhs_inner;
+ void verify(const LookFor &fun) const {
+ EXPECT_TRUE(fun.result_is_mutable());
+ EXPECT_EQUAL(fun.lhs_size(), lhs_size);
+ EXPECT_EQUAL(fun.common_size(), common_size);
+ EXPECT_EQUAL(fun.rhs_size(), rhs_size);
+ EXPECT_EQUAL(fun.matmul_cnt(), matmul_cnt);
+ EXPECT_EQUAL(fun.lhs_common_inner(), lhs_inner);
+ EXPECT_EQUAL(fun.rhs_common_inner(), rhs_inner);
}
- return result;
-}
-
-EvalFixture::ParamRepo make_params() {
- return EvalFixture::ParamRepo()
- .add_variants("A2B1C3a2d3", G({{"A",2}, {"B",1}, {"C",3}, {"a",2}, {"d",3}})) // inner/inner
- .add_variants("A2B1C3D1a2c1d3e1", G({{"A",2}, {"B",1}, {"C",3}, {"D",1}, {"a",2}, {"c",1}, {"d",3}, {"e",1}}))// inner/inner, extra dims
- .add_variants("B1C3a2d3", G({{"B",1}, {"C",3}, {"a",2}, {"d",3}})) // inner/inner, missing A
- .add_variants("A1a2d3", G({{"A",1}, {"a",2}, {"d",3}})) // inner/inner, single mat
- .add_variants("A2D3a2b1c3", G({{"A",2}, {"D",3}, {"a",2}, {"b",1}, {"c",3}})) // inner/inner, inverted
- .add_variants("A2B1C3a2b5", G({{"A",2}, {"B",1}, {"C",3}, {"a",2}, {"b",5}})) // inner/outer
- .add_variants("A2B1C3b5c2", G({{"A",2}, {"B",1}, {"C",3}, {"b",5}, {"c",2}})) // outer/outer
- .add_variants("A2B1C3a2c3", G({{"A",2}, {"B",1}, {"C",3}, {"a",2}, {"c",3}})) // not matching
- //----------------------------------------------------------------------------------------
- .add_variants("A2B1C3b5d3", G({{"A",2}, {"B",1}, {"C",3}, {"b",5}, {"d",3}})) // fixed param
- .add_variants("B1C3b5d3", G({{"B",1}, {"C",3}, {"b",5}, {"d",3}})) // fixed param, missing A
- .add_variants("A1b5d3", G({{"A",1}, {"b",5}, {"d",3}})) // fixed param, single mat
- .add_variants("B5D3a2b1c3", G({{"B",5}, {"D",3}, {"a",2}, {"b",1}, {"c",3}})); // fixed param, inverted
-}
-EvalFixture::ParamRepo param_repo = make_params();
+};
-void verify_optimized(const vespalib::string &expr,
- size_t lhs_size, size_t common_size, size_t rhs_size, size_t matmul_cnt,
- bool lhs_inner, bool rhs_inner)
+void verify_optimized(const vespalib::string &expr, const FunInfo &details)
{
- EvalFixture slow_fixture(prod_factory, expr, param_repo, false);
- EvalFixture fixture(prod_factory, expr, param_repo, true);
- EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo));
- EXPECT_EQUAL(fixture.result(), slow_fixture.result());
- auto info = fixture.find_all<DenseMultiMatMulFunction>();
- ASSERT_EQUAL(info.size(), 1u);
- EXPECT_TRUE(info[0]->result_is_mutable());
- EXPECT_EQUAL(info[0]->lhs_size(), lhs_size);
- EXPECT_EQUAL(info[0]->common_size(), common_size);
- EXPECT_EQUAL(info[0]->rhs_size(), rhs_size);
- EXPECT_EQUAL(info[0]->matmul_cnt(), matmul_cnt);
- EXPECT_EQUAL(info[0]->lhs_common_inner(), lhs_inner);
- EXPECT_EQUAL(info[0]->rhs_common_inner(), rhs_inner);
+ TEST_STATE(expr.c_str());
+ auto same_types = CellTypeSpace(CellTypeUtils::list_types(), 2).same();
+ EvalFixture::verify<FunInfo>(expr, {details}, same_types);
+ auto diff_types = CellTypeSpace(CellTypeUtils::list_types(), 2).different();
+ EvalFixture::verify<FunInfo>(expr, {}, diff_types);
}
void verify_not_optimized(const vespalib::string &expr) {
- EvalFixture slow_fixture(prod_factory, expr, param_repo, false);
- EvalFixture fixture(prod_factory, expr, param_repo, true);
- EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo));
- EXPECT_EQUAL(fixture.result(), slow_fixture.result());
- auto info = fixture.find_all<DenseMultiMatMulFunction>();
- EXPECT_TRUE(info.empty());
+ TEST_STATE(expr.c_str());
+ CellTypeSpace just_double({CellType::DOUBLE}, 2);
+ EvalFixture::verify<FunInfo>(expr, {}, just_double);
}
TEST("require that multi matmul can be optimized") {
- TEST_DO(verify_optimized("reduce(A2B1C3a2d3*A2B1C3b5d3,sum,d)", 2, 3, 5, 6, true, true));
+ FunInfo details = { .lhs_size = 2, .common_size = 3, .rhs_size = 5,
+ .matmul_cnt = 6, .lhs_inner = true, .rhs_inner = true };
+ TEST_DO(verify_optimized("reduce(A2B1C3a2d3*A2B1C3b5d3,sum,d)", details));
+ TEST_DO(verify_optimized("reduce(A2B1C3a2d3*A2B1C3b5d3,sum,d)", details));
}
TEST("require that single multi matmul can be optimized") {
- TEST_DO(verify_optimized("reduce(A1a2d3*A1b5d3,sum,d)", 2, 3, 5, 1, true, true));
+ FunInfo details = { .lhs_size = 2, .common_size = 3, .rhs_size = 5,
+ .matmul_cnt = 1, .lhs_inner = true, .rhs_inner = true };
+ TEST_DO(verify_optimized("reduce(A1a2d3*A1b5d3,sum,d)", details));
}
TEST("require that multi matmul with lambda can be optimized") {
- TEST_DO(verify_optimized("reduce(join(A2B1C3a2d3,A2B1C3b5d3,f(x,y)(x*y)),sum,d)", 2, 3, 5, 6, true, true));
+ FunInfo details = { .lhs_size = 2, .common_size = 3, .rhs_size = 5,
+ .matmul_cnt = 6, .lhs_inner = true, .rhs_inner = true };
+ TEST_DO(verify_optimized("reduce(join(A2B1C3a2d3,A2B1C3b5d3,f(x,y)(x*y)),sum,d)", details));
}
TEST("require that expressions similar to multi matmul are not optimized") {
@@ -97,11 +84,6 @@ TEST("require that expressions similar to multi matmul are not optimized") {
TEST_DO(verify_not_optimized("reduce(A2B1C3a2c3*A2B1C3b5d3,sum,c)"));
}
-TEST("require that multi matmul must have matching cell type") {
- TEST_DO(verify_not_optimized("reduce(A2B1C3a2d3_f*A2B1C3b5d3,sum,d)"));
- TEST_DO(verify_not_optimized("reduce(A2B1C3a2d3*A2B1C3b5d3_f,sum,d)"));
-}
-
TEST("require that multi matmul must have matching dimension prefix") {
TEST_DO(verify_not_optimized("reduce(B1C3a2d3*A2B1C3b5d3,sum,d)"));
TEST_DO(verify_not_optimized("reduce(A2B1C3a2d3*B1C3b5d3,sum,d)"));
@@ -113,51 +95,37 @@ TEST("require that multi matmul must have inner nesting of matmul dimensions") {
}
TEST("require that multi matmul ignores trivial dimensions") {
- TEST_DO(verify_optimized("reduce(A2B1C3D1a2c1d3e1*A2B1C3b5d3,sum,d)", 2, 3, 5, 6, true, true));
- TEST_DO(verify_optimized("reduce(A2B1C3b5d3*A2B1C3D1a2c1d3e1,sum,d)", 2, 3, 5, 6, true, true));
+ FunInfo details = { .lhs_size = 2, .common_size = 3, .rhs_size = 5,
+ .matmul_cnt = 6, .lhs_inner = true, .rhs_inner = true };
+ TEST_DO(verify_optimized("reduce(A2B1C3D1a2c1d3e1*A2B1C3b5d3,sum,d)", details));
+ TEST_DO(verify_optimized("reduce(A2B1C3b5d3*A2B1C3D1a2c1d3e1,sum,d)", details));
}
TEST("require that multi matmul function can be debug dumped") {
- EvalFixture fixture(prod_factory, "reduce(A2B1C3a2d3*A2B1C3b5d3,sum,d)", param_repo, true);
+ EvalFixture fixture(prod_factory, "reduce(m1*m2,sum,d)", EvalFixture::ParamRepo()
+ .add("m1", GenSpec::from_desc("A2B1C3a2d3"))
+ .add("m2", GenSpec::from_desc("A2B1C3b5d3")), true);
auto info = fixture.find_all<DenseMultiMatMulFunction>();
ASSERT_EQUAL(info.size(), 1u);
fprintf(stderr, "%s\n", info[0]->as_string().c_str());
}
-vespalib::string make_expr(const vespalib::string &a, const vespalib::string &b, const vespalib::string &common,
- bool float_cells)
-{
- return make_string("reduce(%s%s*%s%s,sum,%s)", a.c_str(), float_cells ? "_f" : "", b.c_str(), float_cells ? "_f" : "", common.c_str());
-}
-
-void verify_optimized_multi(const vespalib::string &a, const vespalib::string &b, const vespalib::string &common,
- size_t lhs_size, size_t common_size, size_t rhs_size, size_t matmul_cnt,
- bool lhs_inner, bool rhs_inner)
-{
- for (bool float_cells: {false, true}) {
- {
- auto expr = make_expr(a, b, common, float_cells);
- TEST_STATE(expr.c_str());
- TEST_DO(verify_optimized(expr, lhs_size, common_size, rhs_size, matmul_cnt, lhs_inner, rhs_inner));
- }
- {
- auto expr = make_expr(b, a, common, float_cells);
- TEST_STATE(expr.c_str());
- TEST_DO(verify_optimized(expr, lhs_size, common_size, rhs_size, matmul_cnt, lhs_inner, rhs_inner));
- }
- }
-}
-
TEST("require that multi matmul inner/inner works correctly") {
- TEST_DO(verify_optimized_multi("A2B1C3a2d3", "A2B1C3b5d3", "d", 2, 3, 5, 6, true, true));
+ FunInfo details = { .lhs_size = 2, .common_size = 3, .rhs_size = 5,
+ .matmul_cnt = 6, .lhs_inner = true, .rhs_inner = true };
+ TEST_DO(verify_optimized("reduce(A2B1C3a2d3*A2B1C3b5d3,sum,d)", details));
}
TEST("require that multi matmul inner/outer works correctly") {
- TEST_DO(verify_optimized_multi("A2B1C3a2b5", "A2B1C3b5d3", "b", 2, 5, 3, 6, true, false));
+ FunInfo details = { .lhs_size = 2, .common_size = 5, .rhs_size = 3,
+ .matmul_cnt = 6, .lhs_inner = true, .rhs_inner = false };
+ TEST_DO(verify_optimized("reduce(A2B1C3a2b5*A2B1C3b5d3,sum,b)", details));
}
TEST("require that multi matmul outer/outer works correctly") {
- TEST_DO(verify_optimized_multi("A2B1C3b5c2", "A2B1C3b5d3", "b", 2, 5, 3, 6, false, false));
+ FunInfo details = { .lhs_size = 2, .common_size = 5, .rhs_size = 3,
+ .matmul_cnt = 6, .lhs_inner = false, .rhs_inner = false };
+ TEST_DO(verify_optimized("reduce(A2B1C3b5c2*A2B1C3b5d3,sum,b)", details));
}
TEST_MAIN() { TEST_RUN_ALL(); }