diff options
author | Arne H Juul <arnej27959@users.noreply.github.com> | 2021-03-17 15:04:35 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-03-17 15:04:35 +0100 |
commit | 48ff630a4bc00a32dd2897b441d449f646c1e927 (patch) | |
tree | f281f6fb123793bf524ef805091b5bf876bbfcbc /eval | |
parent | 8654480fafe4d18a8251322bde8cfb0ba11cfa93 (diff) | |
parent | c2a4ee37207ad312db429aac40ec40b29afd01cc (diff) |
Merge pull request #17006 from vespa-engine/arnej/update_dense_dot_product
rewrite DenseDotProductFunction test
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/tests/instruction/dense_dot_product_function/dense_dot_product_function_test.cpp | 87 |
1 files changed, 39 insertions, 48 deletions
diff --git a/eval/src/tests/instruction/dense_dot_product_function/dense_dot_product_function_test.cpp b/eval/src/tests/instruction/dense_dot_product_function/dense_dot_product_function_test.cpp index 2d85627e331..8421564f38a 100644 --- a/eval/src/tests/instruction/dense_dot_product_function/dense_dot_product_function_test.cpp +++ b/eval/src/tests/instruction/dense_dot_product_function/dense_dot_product_function_test.cpp @@ -85,74 +85,65 @@ TEST("require that dot product with equal sizes is correct") { //----------------------------------------------------------------------------- -EvalFixture::ParamRepo make_params() { - return EvalFixture::ParamRepo() - .add("v01_x1", GenSpec(2.0).idx("x", 1)) - .add("v02_x3", GenSpec(4.0).idx("x", 3)) - .add("v03_x3", GenSpec(5.0).idx("x", 3)) - .add("v04_y3", GenSpec(10).idx("y", 3)) - .add("v05_x5", GenSpec(6.0).idx("x", 5)) - .add("v06_x5", GenSpec(7.0).idx("x", 5)) - .add("v07_x5f", GenSpec(7.0).cells_float().idx("x", 5)) - .add("v08_x5f", GenSpec(6.0).cells_float().idx("x", 5)) - .add("m01_x3y3", GenSpec(1.0).idx("x", 3).idx("y", 3)) - .add("m02_x3y3", GenSpec(2.0).idx("x", 3).idx("y", 3)); -} -EvalFixture::ParamRepo param_repo = make_params(); +struct FunInfo { + using LookFor = DenseDotProductFunction; + void verify(const LookFor &fun) const { + EXPECT_TRUE(fun.result_is_mutable()); + } +}; + void assertOptimized(const vespalib::string &expr) { - EvalFixture fixture(prod_factory, expr, param_repo, true); - EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo)); - auto info = fixture.find_all<DenseDotProductFunction>(); - ASSERT_EQUAL(info.size(), 1u); - EXPECT_TRUE(info[0]->result_is_mutable()); + TEST_STATE(expr.c_str()); + auto all_types = CellTypeSpace(CellTypeUtils::list_types(), 2); + EvalFixture::verify<FunInfo>(expr, {FunInfo{}}, all_types); + } void assertNotOptimized(const vespalib::string &expr) { - EvalFixture fixture(prod_factory, expr, param_repo, true); - EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo)); - auto info = fixture.find_all<DenseDotProductFunction>(); - EXPECT_TRUE(info.empty()); + TEST_STATE(expr.c_str()); + CellTypeSpace just_double({CellType::DOUBLE}, 2); + EvalFixture::verify<FunInfo>(expr, {}, just_double); } TEST("require that dot product works with tensor function") { - TEST_DO(assertOptimized("reduce(v05_x5*v06_x5,sum)")); - TEST_DO(assertOptimized("reduce(v05_x5*v06_x5,sum,x)")); - TEST_DO(assertOptimized("reduce(join(v05_x5,v06_x5,f(x,y)(x*y)),sum)")); - TEST_DO(assertOptimized("reduce(join(v05_x5,v06_x5,f(x,y)(x*y)),sum,x)")); + TEST_DO(assertOptimized("reduce(x5$1*x5$2,sum)")); + TEST_DO(assertOptimized("reduce(x5$1*x5$2,sum,x)")); + TEST_DO(assertOptimized("reduce(join(x5$1,x5$2,f(x,y)(x*y)),sum)")); + TEST_DO(assertOptimized("reduce(join(x5$1,x5$2,f(x,y)(x*y)),sum,x)")); } TEST("require that dot product with compatible dimensions is optimized") { - TEST_DO(assertOptimized("reduce(v01_x1*v01_x1,sum)")); - TEST_DO(assertOptimized("reduce(v02_x3*v03_x3,sum)")); - TEST_DO(assertOptimized("reduce(v05_x5*v06_x5,sum)")); + TEST_DO(assertOptimized("reduce(x1$1*x1$2,sum)")); + TEST_DO(assertOptimized("reduce(x3$1*x3$2,sum)")); + TEST_DO(assertOptimized("reduce(x5$1*x5$2,sum)")); } TEST("require that dot product with incompatible dimensions is NOT optimized") { - TEST_DO(assertNotOptimized("reduce(v02_x3*v04_y3,sum)")); - TEST_DO(assertNotOptimized("reduce(v04_y3*v02_x3,sum)")); - TEST_DO(assertNotOptimized("reduce(v02_x3*m01_x3y3,sum)")); - TEST_DO(assertNotOptimized("reduce(m01_x3y3*v02_x3,sum)")); + TEST_DO(assertNotOptimized("reduce(x3*y3,sum)")); + TEST_DO(assertNotOptimized("reduce(y3*x3,sum)")); + TEST_DO(assertNotOptimized("reduce(x3*x3y3,sum)")); + TEST_DO(assertNotOptimized("reduce(x3y3*x3,sum)")); } TEST("require that expressions similar to dot product are not optimized") { - TEST_DO(assertNotOptimized("reduce(v02_x3*v03_x3,prod)")); - TEST_DO(assertNotOptimized("reduce(v02_x3+v03_x3,sum)")); - TEST_DO(assertNotOptimized("reduce(join(v02_x3,v03_x3,f(x,y)(x+y)),sum)")); - TEST_DO(assertNotOptimized("reduce(join(v02_x3,v03_x3,f(x,y)(x*x)),sum)")); - TEST_DO(assertNotOptimized("reduce(join(v02_x3,v03_x3,f(x,y)(y*y)),sum)")); - // TEST_DO(assertNotOptimized("reduce(join(v02_x3,v03_x3,f(x,y)(y*x)),sum)")); + TEST_DO(assertNotOptimized("reduce(x3$1*x3$2,prod)")); + TEST_DO(assertNotOptimized("reduce(x3$1+x3$2,sum)")); + TEST_DO(assertNotOptimized("reduce(join(x3$1,x3$2,f(x,y)(x+y)),sum)")); + TEST_DO(assertNotOptimized("reduce(join(x3$1,x3$2,f(x,y)(x*x)),sum)")); + TEST_DO(assertNotOptimized("reduce(join(x3$1,x3$2,f(x,y)(y*y)),sum)")); + // TEST_DO(assertNotOptimized("reduce(join(x3$1,x3$2,f(x,y)(y*x)),sum)")); } TEST("require that multi-dimensional dot product can be optimized") { - TEST_DO(assertOptimized("reduce(m01_x3y3*m02_x3y3,sum)")); - TEST_DO(assertOptimized("reduce(m02_x3y3*m01_x3y3,sum)")); + TEST_DO(assertOptimized("reduce(x3y3$1*x3y3$2,sum)")); + TEST_DO(assertOptimized("reduce(x3y3$1*x3y3$2,sum)")); } TEST("require that result must be double to trigger optimization") { - TEST_DO(assertOptimized("reduce(m01_x3y3*m01_x3y3,sum,x,y)")); - TEST_DO(assertNotOptimized("reduce(m01_x3y3*m01_x3y3,sum,x)")); - TEST_DO(assertNotOptimized("reduce(m01_x3y3*m01_x3y3,sum,y)")); + TEST_DO(assertOptimized("reduce(x3y3$1*x3y3$2,sum,x,y)")); + TEST_DO(assertNotOptimized("reduce(x3y3$1*x3y3$2,sum,x)")); + TEST_DO(assertNotOptimized("reduce(x3y3$1*x3y3$2,sum,y)")); } void verify_compatible(const vespalib::string &a, const vespalib::string &b) { @@ -186,9 +177,9 @@ TEST("require that type compatibility test is appropriate") { } TEST("require that optimization also works for tensors with non-double cells") { - TEST_DO(assertOptimized("reduce(v05_x5*v07_x5f,sum)")); - TEST_DO(assertOptimized("reduce(v07_x5f*v05_x5,sum)")); - TEST_DO(assertOptimized("reduce(v07_x5f*v08_x5f,sum)")); + TEST_DO(assertOptimized("reduce(x5$1*x5$2,sum)")); + TEST_DO(assertOptimized("reduce(x5$1*x5$2,sum)")); + TEST_DO(assertOptimized("reduce(x5$1*x5$2,sum)")); } //----------------------------------------------------------------------------- |