diff options
Diffstat (limited to 'eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp')
-rw-r--r-- | eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp | 59 |
1 files changed, 58 insertions, 1 deletions
diff --git a/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp b/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp index 37f9602565d..60830e4abd7 100644 --- a/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp +++ b/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp @@ -123,7 +123,11 @@ EvalFixture::ParamRepo make_params() { .add("v07_x3_a", spec({x(3)}, MyVecSeq(8.0)), "any") .add("v08_x3_u", spec({x(3)}, MyVecSeq(9.0)), "tensor(x[])") .add("v09_x4_u", spec({x(4)}, MyVecSeq(3.0)), "tensor(x[])") - .add("m01_x3y3", spec({x(3),y(3)}, MyVecSeq(0))); + .add("m01_x3y3", spec({x(3),y(3)}, MyVecSeq(1.0))) + .add("m02_x2y3", spec({x(2),y(3)}, MyVecSeq(2.0))) + .add("m03_x3y2", spec({x(3),y(2)}, MyVecSeq(3.0))) + .add("m04_xuy3", spec({x(3),y(3)}, MyVecSeq(4.0)), "tensor(x[],y[3])") + .add("m05_x3yu", spec({x(3),y(3)}, MyVecSeq(5.0)), "tensor(x[3],y[])"); } EvalFixture::ParamRepo param_repo = make_params(); @@ -183,6 +187,59 @@ TEST("require that expressions similar to dot product are not optimized") { // TEST_DO(assertNotOptimized("reduce(join(v02_x3,v03_x3,f(x,y)(y*x)),sum)")); } +TEST("require that multi-dimensional dot product can be optimized") { + TEST_DO(assertOptimized("reduce(m01_x3y3*m02_x2y3,sum)")); + TEST_DO(assertOptimized("reduce(m02_x2y3*m01_x3y3,sum)")); + TEST_DO(assertOptimized("reduce(m01_x3y3*m04_xuy3,sum)")); + TEST_DO(assertOptimized("reduce(m04_xuy3*m01_x3y3,sum)")); + TEST_DO(assertOptimized("reduce(m04_xuy3*m04_xuy3,sum)")); +} + +TEST("require that result must be double to trigger optimization") { + TEST_DO(assertOptimized("reduce(m01_x3y3*m01_x3y3,sum,x,y)")); + TEST_DO(assertNotOptimized("reduce(m01_x3y3*m01_x3y3,sum,x)")); + TEST_DO(assertNotOptimized("reduce(m01_x3y3*m01_x3y3,sum,y)")); +} + +TEST("require that additional dimensions must have matching size") { + TEST_DO(assertOptimized("reduce(m01_x3y3*m01_x3y3,sum)")); + TEST_DO(assertNotOptimized("reduce(m01_x3y3*m03_x3y2,sum)")); + TEST_DO(assertNotOptimized("reduce(m03_x3y2*m01_x3y3,sum)")); + TEST_DO(assertNotOptimized("reduce(m01_x3y3*m05_x3yu,sum)")); + TEST_DO(assertNotOptimized("reduce(m05_x3yu*m01_x3y3,sum)")); +} + +void verify_compatible(const vespalib::string &a, const vespalib::string &b) { + auto a_type = ValueType::from_spec(a); + auto b_type = ValueType::from_spec(b); + EXPECT_TRUE(!a_type.is_error()); + EXPECT_TRUE(!b_type.is_error()); + EXPECT_TRUE(DenseDotProductFunction::compatible_types(ValueType::double_type(), a_type, b_type)); + EXPECT_TRUE(DenseDotProductFunction::compatible_types(ValueType::double_type(), b_type, a_type)); +} + +void verify_not_compatible(const vespalib::string &a, const vespalib::string &b) { + auto a_type = ValueType::from_spec(a); + auto b_type = ValueType::from_spec(b); + EXPECT_TRUE(!a_type.is_error()); + EXPECT_TRUE(!b_type.is_error()); + EXPECT_TRUE(!DenseDotProductFunction::compatible_types(ValueType::double_type(), a_type, b_type)); + EXPECT_TRUE(!DenseDotProductFunction::compatible_types(ValueType::double_type(), b_type, a_type)); +} + +TEST("require that type compatibility test is appropriate") { + TEST_DO(verify_compatible("tensor(x[5])", "tensor(x[5])")); + TEST_DO(verify_not_compatible("tensor(x[5])", "tensor(y[5])")); + TEST_DO(verify_compatible("tensor(x[5])", "tensor(x[3])")); + TEST_DO(verify_compatible("tensor(x[])", "tensor(x[3])")); + TEST_DO(verify_compatible("tensor(x[3],y[7],z[9])", "tensor(x[5],y[7],z[9])")); + TEST_DO(verify_compatible("tensor(x[3],y[7],z[9])", "tensor(x[],y[7],z[9])")); + TEST_DO(verify_not_compatible("tensor(x[5],y[7],z[9])", "tensor(x[5],y[5],z[9])")); + TEST_DO(verify_not_compatible("tensor(x[5],y[],z[9])", "tensor(x[5],y[7],z[9])")); + TEST_DO(verify_not_compatible("tensor(x[5],y[7],z[9])", "tensor(x[5],y[7],z[5])")); + TEST_DO(verify_not_compatible("tensor(x[5],y[7],z[])", "tensor(x[5],y[7],z[9])")); +} + //----------------------------------------------------------------------------- TEST_MAIN() { TEST_RUN_ALL(); } |