aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp')
-rw-r--r--eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp59
1 files changed, 58 insertions, 1 deletions
diff --git a/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp b/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp
index 37f9602565d..60830e4abd7 100644
--- a/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp
+++ b/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp
@@ -123,7 +123,11 @@ EvalFixture::ParamRepo make_params() {
.add("v07_x3_a", spec({x(3)}, MyVecSeq(8.0)), "any")
.add("v08_x3_u", spec({x(3)}, MyVecSeq(9.0)), "tensor(x[])")
.add("v09_x4_u", spec({x(4)}, MyVecSeq(3.0)), "tensor(x[])")
- .add("m01_x3y3", spec({x(3),y(3)}, MyVecSeq(0)));
+ .add("m01_x3y3", spec({x(3),y(3)}, MyVecSeq(1.0)))
+ .add("m02_x2y3", spec({x(2),y(3)}, MyVecSeq(2.0)))
+ .add("m03_x3y2", spec({x(3),y(2)}, MyVecSeq(3.0)))
+ .add("m04_xuy3", spec({x(3),y(3)}, MyVecSeq(4.0)), "tensor(x[],y[3])")
+ .add("m05_x3yu", spec({x(3),y(3)}, MyVecSeq(5.0)), "tensor(x[3],y[])");
}
EvalFixture::ParamRepo param_repo = make_params();
@@ -183,6 +187,59 @@ TEST("require that expressions similar to dot product are not optimized") {
// TEST_DO(assertNotOptimized("reduce(join(v02_x3,v03_x3,f(x,y)(y*x)),sum)"));
}
+TEST("require that multi-dimensional dot product can be optimized") {
+ TEST_DO(assertOptimized("reduce(m01_x3y3*m02_x2y3,sum)"));
+ TEST_DO(assertOptimized("reduce(m02_x2y3*m01_x3y3,sum)"));
+ TEST_DO(assertOptimized("reduce(m01_x3y3*m04_xuy3,sum)"));
+ TEST_DO(assertOptimized("reduce(m04_xuy3*m01_x3y3,sum)"));
+ TEST_DO(assertOptimized("reduce(m04_xuy3*m04_xuy3,sum)"));
+}
+
+TEST("require that result must be double to trigger optimization") {
+ TEST_DO(assertOptimized("reduce(m01_x3y3*m01_x3y3,sum,x,y)"));
+ TEST_DO(assertNotOptimized("reduce(m01_x3y3*m01_x3y3,sum,x)"));
+ TEST_DO(assertNotOptimized("reduce(m01_x3y3*m01_x3y3,sum,y)"));
+}
+
+TEST("require that additional dimensions must have matching size") {
+ TEST_DO(assertOptimized("reduce(m01_x3y3*m01_x3y3,sum)"));
+ TEST_DO(assertNotOptimized("reduce(m01_x3y3*m03_x3y2,sum)"));
+ TEST_DO(assertNotOptimized("reduce(m03_x3y2*m01_x3y3,sum)"));
+ TEST_DO(assertNotOptimized("reduce(m01_x3y3*m05_x3yu,sum)"));
+ TEST_DO(assertNotOptimized("reduce(m05_x3yu*m01_x3y3,sum)"));
+}
+
+void verify_compatible(const vespalib::string &a, const vespalib::string &b) {
+ auto a_type = ValueType::from_spec(a);
+ auto b_type = ValueType::from_spec(b);
+ EXPECT_TRUE(!a_type.is_error());
+ EXPECT_TRUE(!b_type.is_error());
+ EXPECT_TRUE(DenseDotProductFunction::compatible_types(ValueType::double_type(), a_type, b_type));
+ EXPECT_TRUE(DenseDotProductFunction::compatible_types(ValueType::double_type(), b_type, a_type));
+}
+
+void verify_not_compatible(const vespalib::string &a, const vespalib::string &b) {
+ auto a_type = ValueType::from_spec(a);
+ auto b_type = ValueType::from_spec(b);
+ EXPECT_TRUE(!a_type.is_error());
+ EXPECT_TRUE(!b_type.is_error());
+ EXPECT_TRUE(!DenseDotProductFunction::compatible_types(ValueType::double_type(), a_type, b_type));
+ EXPECT_TRUE(!DenseDotProductFunction::compatible_types(ValueType::double_type(), b_type, a_type));
+}
+
+TEST("require that type compatibility test is appropriate") {
+ TEST_DO(verify_compatible("tensor(x[5])", "tensor(x[5])"));
+ TEST_DO(verify_not_compatible("tensor(x[5])", "tensor(y[5])"));
+ TEST_DO(verify_compatible("tensor(x[5])", "tensor(x[3])"));
+ TEST_DO(verify_compatible("tensor(x[])", "tensor(x[3])"));
+ TEST_DO(verify_compatible("tensor(x[3],y[7],z[9])", "tensor(x[5],y[7],z[9])"));
+ TEST_DO(verify_compatible("tensor(x[3],y[7],z[9])", "tensor(x[],y[7],z[9])"));
+ TEST_DO(verify_not_compatible("tensor(x[5],y[7],z[9])", "tensor(x[5],y[5],z[9])"));
+ TEST_DO(verify_not_compatible("tensor(x[5],y[],z[9])", "tensor(x[5],y[7],z[9])"));
+ TEST_DO(verify_not_compatible("tensor(x[5],y[7],z[9])", "tensor(x[5],y[7],z[5])"));
+ TEST_DO(verify_not_compatible("tensor(x[5],y[7],z[])", "tensor(x[5],y[7],z[9])"));
+}
+
//-----------------------------------------------------------------------------
TEST_MAIN() { TEST_RUN_ALL(); }