summaryrefslogtreecommitdiffstats
path: root/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp')
-rw-r--r--eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp56
1 files changed, 3 insertions, 53 deletions
diff --git a/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp b/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp
index 60830e4abd7..fae5db75618 100644
--- a/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp
+++ b/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp
@@ -62,11 +62,6 @@ TEST("require that basic dot product with equal sizes is correct") {
check_gen_with_result(2, 2, (3.0 * 5.0) + (4.0 * 6.0));
}
-TEST("require that basic dot product with un-equal sizes is correct") {
- check_gen_with_result(2, 3, (3.0 * 5.0) + (4.0 * 6.0));
- check_gen_with_result(3, 2, (3.0 * 5.0) + (4.0 * 6.0));
-}
-
//-----------------------------------------------------------------------------
void assertDotProduct(size_t numCells) {
@@ -98,18 +93,6 @@ TEST("require that dot product with equal sizes is correct") {
TEST_DO(assertDotProduct(1024 + 3));
}
-TEST("require that dot product with un-equal sizes is correct") {
- TEST_DO(assertDotProduct(8, 8 + 3));
- TEST_DO(assertDotProduct(8 + 3, 8));
- TEST_DO(assertDotProduct(16, 16 + 3));
- TEST_DO(assertDotProduct(32, 32 + 3));
- TEST_DO(assertDotProduct(64, 64 + 3));
- TEST_DO(assertDotProduct(128, 128 + 3));
- TEST_DO(assertDotProduct(256, 256 + 3));
- TEST_DO(assertDotProduct(512, 512 + 3));
- TEST_DO(assertDotProduct(1024, 1024 + 3));
-}
-
//-----------------------------------------------------------------------------
EvalFixture::ParamRepo make_params() {
@@ -120,14 +103,8 @@ EvalFixture::ParamRepo make_params() {
.add("v04_y3", spec({y(3)}, MyVecSeq(10)))
.add("v05_x5", spec({x(5)}, MyVecSeq(6.0)))
.add("v06_x5", spec({x(5)}, MyVecSeq(7.0)))
- .add("v07_x3_a", spec({x(3)}, MyVecSeq(8.0)), "any")
- .add("v08_x3_u", spec({x(3)}, MyVecSeq(9.0)), "tensor(x[])")
- .add("v09_x4_u", spec({x(4)}, MyVecSeq(3.0)), "tensor(x[])")
.add("m01_x3y3", spec({x(3),y(3)}, MyVecSeq(1.0)))
- .add("m02_x2y3", spec({x(2),y(3)}, MyVecSeq(2.0)))
- .add("m03_x3y2", spec({x(3),y(2)}, MyVecSeq(3.0)))
- .add("m04_xuy3", spec({x(3),y(3)}, MyVecSeq(4.0)), "tensor(x[],y[3])")
- .add("m05_x3yu", spec({x(3),y(3)}, MyVecSeq(5.0)), "tensor(x[3],y[])");
+ .add("m02_x3y3", spec({x(3),y(3)}, MyVecSeq(2.0)));
}
EvalFixture::ParamRepo param_repo = make_params();
@@ -146,11 +123,6 @@ void assertNotOptimized(const vespalib::string &expr) {
EXPECT_TRUE(info.empty());
}
-TEST("require that dot product is not optimized for unknown types") {
- TEST_DO(assertNotOptimized("reduce(v02_x3*v07_x3_a,sum)"));
- TEST_DO(assertNotOptimized("reduce(v07_x3_a*v03_x3,sum)"));
-}
-
TEST("require that dot product works with tensor function") {
TEST_DO(assertOptimized("reduce(v05_x5*v06_x5,sum)"));
TEST_DO(assertOptimized("reduce(v05_x5*v06_x5,sum,x)"));
@@ -162,18 +134,11 @@ TEST("require that dot product with compatible dimensions is optimized") {
TEST_DO(assertOptimized("reduce(v01_x1*v01_x1,sum)"));
TEST_DO(assertOptimized("reduce(v02_x3*v03_x3,sum)"));
TEST_DO(assertOptimized("reduce(v05_x5*v06_x5,sum)"));
-
- TEST_DO(assertOptimized("reduce(v02_x3*v06_x5,sum)"));
- TEST_DO(assertOptimized("reduce(v05_x5*v03_x3,sum)"));
- TEST_DO(assertOptimized("reduce(v08_x3_u*v05_x5,sum)"));
- TEST_DO(assertOptimized("reduce(v05_x5*v08_x3_u,sum)"));
}
TEST("require that dot product with incompatible dimensions is NOT optimized") {
TEST_DO(assertNotOptimized("reduce(v02_x3*v04_y3,sum)"));
TEST_DO(assertNotOptimized("reduce(v04_y3*v02_x3,sum)"));
- TEST_DO(assertNotOptimized("reduce(v08_x3_u*v04_y3,sum)"));
- TEST_DO(assertNotOptimized("reduce(v04_y3*v08_x3_u,sum)"));
TEST_DO(assertNotOptimized("reduce(v02_x3*m01_x3y3,sum)"));
TEST_DO(assertNotOptimized("reduce(m01_x3y3*v02_x3,sum)"));
}
@@ -188,11 +153,8 @@ TEST("require that expressions similar to dot product are not optimized") {
}
TEST("require that multi-dimensional dot product can be optimized") {
- TEST_DO(assertOptimized("reduce(m01_x3y3*m02_x2y3,sum)"));
- TEST_DO(assertOptimized("reduce(m02_x2y3*m01_x3y3,sum)"));
- TEST_DO(assertOptimized("reduce(m01_x3y3*m04_xuy3,sum)"));
- TEST_DO(assertOptimized("reduce(m04_xuy3*m01_x3y3,sum)"));
- TEST_DO(assertOptimized("reduce(m04_xuy3*m04_xuy3,sum)"));
+ TEST_DO(assertOptimized("reduce(m01_x3y3*m02_x3y3,sum)"));
+ TEST_DO(assertOptimized("reduce(m02_x3y3*m01_x3y3,sum)"));
}
TEST("require that result must be double to trigger optimization") {
@@ -201,14 +163,6 @@ TEST("require that result must be double to trigger optimization") {
TEST_DO(assertNotOptimized("reduce(m01_x3y3*m01_x3y3,sum,y)"));
}
-TEST("require that additional dimensions must have matching size") {
- TEST_DO(assertOptimized("reduce(m01_x3y3*m01_x3y3,sum)"));
- TEST_DO(assertNotOptimized("reduce(m01_x3y3*m03_x3y2,sum)"));
- TEST_DO(assertNotOptimized("reduce(m03_x3y2*m01_x3y3,sum)"));
- TEST_DO(assertNotOptimized("reduce(m01_x3y3*m05_x3yu,sum)"));
- TEST_DO(assertNotOptimized("reduce(m05_x3yu*m01_x3y3,sum)"));
-}
-
void verify_compatible(const vespalib::string &a, const vespalib::string &b) {
auto a_type = ValueType::from_spec(a);
auto b_type = ValueType::from_spec(b);
@@ -231,13 +185,9 @@ TEST("require that type compatibility test is appropriate") {
TEST_DO(verify_compatible("tensor(x[5])", "tensor(x[5])"));
TEST_DO(verify_not_compatible("tensor(x[5])", "tensor(y[5])"));
TEST_DO(verify_compatible("tensor(x[5])", "tensor(x[3])"));
- TEST_DO(verify_compatible("tensor(x[])", "tensor(x[3])"));
TEST_DO(verify_compatible("tensor(x[3],y[7],z[9])", "tensor(x[5],y[7],z[9])"));
- TEST_DO(verify_compatible("tensor(x[3],y[7],z[9])", "tensor(x[],y[7],z[9])"));
TEST_DO(verify_not_compatible("tensor(x[5],y[7],z[9])", "tensor(x[5],y[5],z[9])"));
- TEST_DO(verify_not_compatible("tensor(x[5],y[],z[9])", "tensor(x[5],y[7],z[9])"));
TEST_DO(verify_not_compatible("tensor(x[5],y[7],z[9])", "tensor(x[5],y[7],z[5])"));
- TEST_DO(verify_not_compatible("tensor(x[5],y[7],z[])", "tensor(x[5],y[7],z[9])"));
}
//-----------------------------------------------------------------------------