summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorArne H Juul <arnej27959@users.noreply.github.com>2021-03-17 15:04:35 +0100
committerGitHub <noreply@github.com>2021-03-17 15:04:35 +0100
commit48ff630a4bc00a32dd2897b441d449f646c1e927 (patch)
treef281f6fb123793bf524ef805091b5bf876bbfcbc /eval
parent8654480fafe4d18a8251322bde8cfb0ba11cfa93 (diff)
parentc2a4ee37207ad312db429aac40ec40b29afd01cc (diff)
Merge pull request #17006 from vespa-engine/arnej/update_dense_dot_product
rewrite DenseDotProductFunction test
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/instruction/dense_dot_product_function/dense_dot_product_function_test.cpp87
1 files changed, 39 insertions, 48 deletions
diff --git a/eval/src/tests/instruction/dense_dot_product_function/dense_dot_product_function_test.cpp b/eval/src/tests/instruction/dense_dot_product_function/dense_dot_product_function_test.cpp
index 2d85627e331..8421564f38a 100644
--- a/eval/src/tests/instruction/dense_dot_product_function/dense_dot_product_function_test.cpp
+++ b/eval/src/tests/instruction/dense_dot_product_function/dense_dot_product_function_test.cpp
@@ -85,74 +85,65 @@ TEST("require that dot product with equal sizes is correct") {
//-----------------------------------------------------------------------------
-EvalFixture::ParamRepo make_params() {
- return EvalFixture::ParamRepo()
- .add("v01_x1", GenSpec(2.0).idx("x", 1))
- .add("v02_x3", GenSpec(4.0).idx("x", 3))
- .add("v03_x3", GenSpec(5.0).idx("x", 3))
- .add("v04_y3", GenSpec(10).idx("y", 3))
- .add("v05_x5", GenSpec(6.0).idx("x", 5))
- .add("v06_x5", GenSpec(7.0).idx("x", 5))
- .add("v07_x5f", GenSpec(7.0).cells_float().idx("x", 5))
- .add("v08_x5f", GenSpec(6.0).cells_float().idx("x", 5))
- .add("m01_x3y3", GenSpec(1.0).idx("x", 3).idx("y", 3))
- .add("m02_x3y3", GenSpec(2.0).idx("x", 3).idx("y", 3));
-}
-EvalFixture::ParamRepo param_repo = make_params();
+struct FunInfo {
+ using LookFor = DenseDotProductFunction;
+ void verify(const LookFor &fun) const {
+ EXPECT_TRUE(fun.result_is_mutable());
+ }
+};
+
void assertOptimized(const vespalib::string &expr) {
- EvalFixture fixture(prod_factory, expr, param_repo, true);
- EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo));
- auto info = fixture.find_all<DenseDotProductFunction>();
- ASSERT_EQUAL(info.size(), 1u);
- EXPECT_TRUE(info[0]->result_is_mutable());
+ TEST_STATE(expr.c_str());
+ auto all_types = CellTypeSpace(CellTypeUtils::list_types(), 2);
+ EvalFixture::verify<FunInfo>(expr, {FunInfo{}}, all_types);
+
}
void assertNotOptimized(const vespalib::string &expr) {
- EvalFixture fixture(prod_factory, expr, param_repo, true);
- EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo));
- auto info = fixture.find_all<DenseDotProductFunction>();
- EXPECT_TRUE(info.empty());
+ TEST_STATE(expr.c_str());
+ CellTypeSpace just_double({CellType::DOUBLE}, 2);
+ EvalFixture::verify<FunInfo>(expr, {}, just_double);
}
TEST("require that dot product works with tensor function") {
- TEST_DO(assertOptimized("reduce(v05_x5*v06_x5,sum)"));
- TEST_DO(assertOptimized("reduce(v05_x5*v06_x5,sum,x)"));
- TEST_DO(assertOptimized("reduce(join(v05_x5,v06_x5,f(x,y)(x*y)),sum)"));
- TEST_DO(assertOptimized("reduce(join(v05_x5,v06_x5,f(x,y)(x*y)),sum,x)"));
+ TEST_DO(assertOptimized("reduce(x5$1*x5$2,sum)"));
+ TEST_DO(assertOptimized("reduce(x5$1*x5$2,sum,x)"));
+ TEST_DO(assertOptimized("reduce(join(x5$1,x5$2,f(x,y)(x*y)),sum)"));
+ TEST_DO(assertOptimized("reduce(join(x5$1,x5$2,f(x,y)(x*y)),sum,x)"));
}
TEST("require that dot product with compatible dimensions is optimized") {
- TEST_DO(assertOptimized("reduce(v01_x1*v01_x1,sum)"));
- TEST_DO(assertOptimized("reduce(v02_x3*v03_x3,sum)"));
- TEST_DO(assertOptimized("reduce(v05_x5*v06_x5,sum)"));
+ TEST_DO(assertOptimized("reduce(x1$1*x1$2,sum)"));
+ TEST_DO(assertOptimized("reduce(x3$1*x3$2,sum)"));
+ TEST_DO(assertOptimized("reduce(x5$1*x5$2,sum)"));
}
TEST("require that dot product with incompatible dimensions is NOT optimized") {
- TEST_DO(assertNotOptimized("reduce(v02_x3*v04_y3,sum)"));
- TEST_DO(assertNotOptimized("reduce(v04_y3*v02_x3,sum)"));
- TEST_DO(assertNotOptimized("reduce(v02_x3*m01_x3y3,sum)"));
- TEST_DO(assertNotOptimized("reduce(m01_x3y3*v02_x3,sum)"));
+ TEST_DO(assertNotOptimized("reduce(x3*y3,sum)"));
+ TEST_DO(assertNotOptimized("reduce(y3*x3,sum)"));
+ TEST_DO(assertNotOptimized("reduce(x3*x3y3,sum)"));
+ TEST_DO(assertNotOptimized("reduce(x3y3*x3,sum)"));
}
TEST("require that expressions similar to dot product are not optimized") {
- TEST_DO(assertNotOptimized("reduce(v02_x3*v03_x3,prod)"));
- TEST_DO(assertNotOptimized("reduce(v02_x3+v03_x3,sum)"));
- TEST_DO(assertNotOptimized("reduce(join(v02_x3,v03_x3,f(x,y)(x+y)),sum)"));
- TEST_DO(assertNotOptimized("reduce(join(v02_x3,v03_x3,f(x,y)(x*x)),sum)"));
- TEST_DO(assertNotOptimized("reduce(join(v02_x3,v03_x3,f(x,y)(y*y)),sum)"));
- // TEST_DO(assertNotOptimized("reduce(join(v02_x3,v03_x3,f(x,y)(y*x)),sum)"));
+ TEST_DO(assertNotOptimized("reduce(x3$1*x3$2,prod)"));
+ TEST_DO(assertNotOptimized("reduce(x3$1+x3$2,sum)"));
+ TEST_DO(assertNotOptimized("reduce(join(x3$1,x3$2,f(x,y)(x+y)),sum)"));
+ TEST_DO(assertNotOptimized("reduce(join(x3$1,x3$2,f(x,y)(x*x)),sum)"));
+ TEST_DO(assertNotOptimized("reduce(join(x3$1,x3$2,f(x,y)(y*y)),sum)"));
+ // TEST_DO(assertNotOptimized("reduce(join(x3$1,x3$2,f(x,y)(y*x)),sum)"));
}
TEST("require that multi-dimensional dot product can be optimized") {
- TEST_DO(assertOptimized("reduce(m01_x3y3*m02_x3y3,sum)"));
- TEST_DO(assertOptimized("reduce(m02_x3y3*m01_x3y3,sum)"));
+ TEST_DO(assertOptimized("reduce(x3y3$1*x3y3$2,sum)"));
+ TEST_DO(assertOptimized("reduce(x3y3$1*x3y3$2,sum)"));
}
TEST("require that result must be double to trigger optimization") {
- TEST_DO(assertOptimized("reduce(m01_x3y3*m01_x3y3,sum,x,y)"));
- TEST_DO(assertNotOptimized("reduce(m01_x3y3*m01_x3y3,sum,x)"));
- TEST_DO(assertNotOptimized("reduce(m01_x3y3*m01_x3y3,sum,y)"));
+ TEST_DO(assertOptimized("reduce(x3y3$1*x3y3$2,sum,x,y)"));
+ TEST_DO(assertNotOptimized("reduce(x3y3$1*x3y3$2,sum,x)"));
+ TEST_DO(assertNotOptimized("reduce(x3y3$1*x3y3$2,sum,y)"));
}
void verify_compatible(const vespalib::string &a, const vespalib::string &b) {
@@ -186,9 +177,9 @@ TEST("require that type compatibility test is appropriate") {
}
TEST("require that optimization also works for tensors with non-double cells") {
- TEST_DO(assertOptimized("reduce(v05_x5*v07_x5f,sum)"));
- TEST_DO(assertOptimized("reduce(v07_x5f*v05_x5,sum)"));
- TEST_DO(assertOptimized("reduce(v07_x5f*v08_x5f,sum)"));
+ TEST_DO(assertOptimized("reduce(x5$1*x5$2,sum)"));
+ TEST_DO(assertOptimized("reduce(x5$1*x5$2,sum)"));
+ TEST_DO(assertOptimized("reduce(x5$1*x5$2,sum)"));
}
//-----------------------------------------------------------------------------