summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2021-03-17 12:53:29 +0000
committerArne Juul <arnej@verizonmedia.com>2021-03-17 14:07:56 +0000
commit12e18f4ae9bb756066f74d7076253d5ad3803987 (patch)
tree03a10e8ef8a1fa53a92139be5205ab2b7b72126c /eval
parenta02f7e628be4196a83f7ab0e379c74a223eefc93 (diff)
rewrite DenseMatMulFunction test using EvalFixture::verify<FunInfo>
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/instruction/dense_matmul_function/dense_matmul_function_test.cpp111
1 files changed, 45 insertions, 66 deletions
diff --git a/eval/src/tests/instruction/dense_matmul_function/dense_matmul_function_test.cpp b/eval/src/tests/instruction/dense_matmul_function/dense_matmul_function_test.cpp
index bcc5e6c7b5d..438424fdf4e 100644
--- a/eval/src/tests/instruction/dense_matmul_function/dense_matmul_function_test.cpp
+++ b/eval/src/tests/instruction/dense_matmul_function/dense_matmul_function_test.cpp
@@ -8,7 +8,6 @@
#include <vespa/eval/instruction/dense_matmul_function.h>
#include <vespa/vespalib/testkit/test_kit.h>
#include <vespa/vespalib/util/stash.h>
-#include <vespa/vespalib/util/stringfmt.h>
using namespace vespalib;
using namespace vespalib::eval;
@@ -17,50 +16,45 @@ using namespace vespalib::eval::tensor_function;
const ValueBuilderFactory &prod_factory = FastValueBuilderFactory::get();
-EvalFixture::ParamRepo make_params() {
- return EvalFixture::ParamRepo()
- .add_variants("a2d3", GenSpec().idx("a", 2).idx("d", 3)) // inner/inner
- .add_variants("a2b5", GenSpec().idx("a", 2).idx("b", 5)) // inner/outer
- .add_variants("b5c2", GenSpec().idx("b", 5).idx("c", 2)) // outer/outer
- .add_variants("a2c3", GenSpec().idx("a", 2).idx("c", 3)) // not matching
- //------------------------------------------
- .add_variants("b5d3", GenSpec().idx("b", 5).idx("d", 3)); // fixed param
-}
-EvalFixture::ParamRepo param_repo = make_params();
+struct FunInfo {
+ using LookFor = DenseMatMulFunction;
+ size_t lhs_size;
+ size_t common_size;
+ size_t rhs_size;
+ bool lhs_inner;
+ bool rhs_inner;
+ void verify(const LookFor &fun) const {
+ EXPECT_TRUE(fun.result_is_mutable());
+ EXPECT_EQUAL(fun.lhs_size(), lhs_size);
+ EXPECT_EQUAL(fun.common_size(), common_size);
+ EXPECT_EQUAL(fun.rhs_size(), rhs_size);
+ EXPECT_EQUAL(fun.lhs_common_inner(), lhs_inner);
+ EXPECT_EQUAL(fun.rhs_common_inner(), rhs_inner);
+ }
+};
-void verify_optimized(const vespalib::string &expr,
- size_t lhs_size, size_t common_size, size_t rhs_size,
- bool lhs_inner, bool rhs_inner)
-{
- EvalFixture slow_fixture(prod_factory, expr, param_repo, false);
- EvalFixture fixture(prod_factory, expr, param_repo, true);
- EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo));
- EXPECT_EQUAL(fixture.result(), slow_fixture.result());
- auto info = fixture.find_all<DenseMatMulFunction>();
- ASSERT_EQUAL(info.size(), 1u);
- EXPECT_TRUE(info[0]->result_is_mutable());
- EXPECT_EQUAL(info[0]->lhs_size(), lhs_size);
- EXPECT_EQUAL(info[0]->common_size(), common_size);
- EXPECT_EQUAL(info[0]->rhs_size(), rhs_size);
- EXPECT_EQUAL(info[0]->lhs_common_inner(), lhs_inner);
- EXPECT_EQUAL(info[0]->rhs_common_inner(), rhs_inner);
+void verify_optimized(const vespalib::string &expr, FunInfo details) {
+ TEST_STATE(expr.c_str());
+ CellTypeSpace all_types(CellTypeUtils::list_types(), 2);
+ EvalFixture::verify<FunInfo>(expr, {details}, all_types);
}
void verify_not_optimized(const vespalib::string &expr) {
- EvalFixture slow_fixture(prod_factory, expr, param_repo, false);
- EvalFixture fixture(prod_factory, expr, param_repo, true);
- EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo));
- EXPECT_EQUAL(fixture.result(), slow_fixture.result());
- auto info = fixture.find_all<DenseMatMulFunction>();
- EXPECT_TRUE(info.empty());
+ TEST_STATE(expr.c_str());
+ CellTypeSpace just_double({CellType::DOUBLE}, 2);
+ EvalFixture::verify<FunInfo>(expr, {}, just_double);
}
TEST("require that matmul can be optimized") {
- TEST_DO(verify_optimized("reduce(a2d3*b5d3,sum,d)", 2, 3, 5, true, true));
+ FunInfo details = { .lhs_size = 2, .common_size = 3, .rhs_size = 5,
+ .lhs_inner = true, .rhs_inner = true };
+ TEST_DO(verify_optimized("reduce(a2d3*b5d3,sum,d)", details));
}
TEST("require that matmul with lambda can be optimized") {
- TEST_DO(verify_optimized("reduce(join(a2d3,b5d3,f(x,y)(x*y)),sum,d)", 2, 3, 5, true, true));
+ FunInfo details = { .lhs_size = 2, .common_size = 3, .rhs_size = 5,
+ .lhs_inner = true, .rhs_inner = true };
+ TEST_DO(verify_optimized("reduce(join(a2d3,b5d3,f(x,y)(x*y)),sum,d)", details));
}
TEST("require that expressions similar to matmul are not optimized") {
@@ -77,49 +71,34 @@ TEST("require that expressions similar to matmul are not optimized") {
TEST_DO(verify_not_optimized("reduce(a2c3*b5d3,sum,c)"));
}
-TEST("require that xw product can be debug dumped") {
- EvalFixture fixture(prod_factory, "reduce(a2d3*b5d3,sum,d)", param_repo, true);
+TEST("require that MatMul can be debug dumped") {
+ EvalFixture fixture(prod_factory, "reduce(x*y,sum,d)", EvalFixture::ParamRepo()
+ .add("x", GenSpec::from_desc("a2d3"))
+ .add("y", GenSpec::from_desc("b5d3")), true);
auto info = fixture.find_all<DenseMatMulFunction>();
ASSERT_EQUAL(info.size(), 1u);
fprintf(stderr, "%s\n", info[0]->as_string().c_str());
}
-vespalib::string make_expr(const vespalib::string &a, const vespalib::string &b, const vespalib::string &common,
- bool float_a, bool float_b)
-{
- return make_string("reduce(%s%s*%s%s,sum,%s)", a.c_str(), float_a ? "_f" : "", b.c_str(), float_b ? "_f" : "", common.c_str());
-}
-
-void verify_optimized_multi(const vespalib::string &a, const vespalib::string &b, const vespalib::string &common,
- size_t lhs_size, size_t common_size, size_t rhs_size,
- bool lhs_inner, bool rhs_inner)
-{
- for (bool float_a: {false, true}) {
- for (bool float_b: {false, true}) {
- {
- auto expr = make_expr(a, b, common, float_a, float_b);
- TEST_STATE(expr.c_str());
- TEST_DO(verify_optimized(expr, lhs_size, common_size, rhs_size, lhs_inner, rhs_inner));
- }
- {
- auto expr = make_expr(b, a, common, float_b, float_a);
- TEST_STATE(expr.c_str());
- TEST_DO(verify_optimized(expr, lhs_size, common_size, rhs_size, lhs_inner, rhs_inner));
- }
- }
- }
-}
-
TEST("require that matmul inner/inner works correctly") {
- TEST_DO(verify_optimized_multi("a2d3", "b5d3", "d", 2, 3, 5, true, true));
+ FunInfo details = { .lhs_size = 2, .common_size = 3, .rhs_size = 5,
+ .lhs_inner = true, .rhs_inner = true };
+ TEST_DO(verify_optimized("reduce(a2d3*b5d3,sum,d)", details));
+ TEST_DO(verify_optimized("reduce(b5d3*a2d3,sum,d)", details));
}
TEST("require that matmul inner/outer works correctly") {
- TEST_DO(verify_optimized_multi("a2b5", "b5d3", "b", 2, 5, 3, true, false));
+ FunInfo details = { .lhs_size = 2, .common_size = 5, .rhs_size = 3,
+ .lhs_inner = true, .rhs_inner = false };
+ TEST_DO(verify_optimized("reduce(a2b5*b5d3,sum,b)", details));
+ TEST_DO(verify_optimized("reduce(b5d3*a2b5,sum,b)", details));
}
TEST("require that matmul outer/outer works correctly") {
- TEST_DO(verify_optimized_multi("b5c2", "b5d3", "b", 2, 5, 3, false, false));
+ FunInfo details = { .lhs_size = 2, .common_size = 5, .rhs_size = 3,
+ .lhs_inner = false, .rhs_inner = false };
+ TEST_DO(verify_optimized("reduce(b5c2*b5d3,sum,b)", details));
+ TEST_DO(verify_optimized("reduce(b5d3*b5c2,sum,b)", details));
}
TEST_MAIN() { TEST_RUN_ALL(); }