diff options
Diffstat (limited to 'eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp')
-rw-r--r-- | eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp | 56 |
1 files changed, 3 insertions, 53 deletions
diff --git a/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp b/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp index 60830e4abd7..fae5db75618 100644 --- a/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp +++ b/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp @@ -62,11 +62,6 @@ TEST("require that basic dot product with equal sizes is correct") { check_gen_with_result(2, 2, (3.0 * 5.0) + (4.0 * 6.0)); } -TEST("require that basic dot product with un-equal sizes is correct") { - check_gen_with_result(2, 3, (3.0 * 5.0) + (4.0 * 6.0)); - check_gen_with_result(3, 2, (3.0 * 5.0) + (4.0 * 6.0)); -} - //----------------------------------------------------------------------------- void assertDotProduct(size_t numCells) { @@ -98,18 +93,6 @@ TEST("require that dot product with equal sizes is correct") { TEST_DO(assertDotProduct(1024 + 3)); } -TEST("require that dot product with un-equal sizes is correct") { - TEST_DO(assertDotProduct(8, 8 + 3)); - TEST_DO(assertDotProduct(8 + 3, 8)); - TEST_DO(assertDotProduct(16, 16 + 3)); - TEST_DO(assertDotProduct(32, 32 + 3)); - TEST_DO(assertDotProduct(64, 64 + 3)); - TEST_DO(assertDotProduct(128, 128 + 3)); - TEST_DO(assertDotProduct(256, 256 + 3)); - TEST_DO(assertDotProduct(512, 512 + 3)); - TEST_DO(assertDotProduct(1024, 1024 + 3)); -} - //----------------------------------------------------------------------------- EvalFixture::ParamRepo make_params() { @@ -120,14 +103,8 @@ EvalFixture::ParamRepo make_params() { .add("v04_y3", spec({y(3)}, MyVecSeq(10))) .add("v05_x5", spec({x(5)}, MyVecSeq(6.0))) .add("v06_x5", spec({x(5)}, MyVecSeq(7.0))) - .add("v07_x3_a", spec({x(3)}, MyVecSeq(8.0)), "any") - .add("v08_x3_u", spec({x(3)}, MyVecSeq(9.0)), "tensor(x[])") - .add("v09_x4_u", spec({x(4)}, MyVecSeq(3.0)), "tensor(x[])") .add("m01_x3y3", spec({x(3),y(3)}, MyVecSeq(1.0))) - .add("m02_x2y3", spec({x(2),y(3)}, MyVecSeq(2.0))) - .add("m03_x3y2", spec({x(3),y(2)}, MyVecSeq(3.0))) - .add("m04_xuy3", spec({x(3),y(3)}, MyVecSeq(4.0)), "tensor(x[],y[3])") - .add("m05_x3yu", spec({x(3),y(3)}, MyVecSeq(5.0)), "tensor(x[3],y[])"); + .add("m02_x3y3", spec({x(3),y(3)}, MyVecSeq(2.0))); } EvalFixture::ParamRepo param_repo = make_params(); @@ -146,11 +123,6 @@ void assertNotOptimized(const vespalib::string &expr) { EXPECT_TRUE(info.empty()); } -TEST("require that dot product is not optimized for unknown types") { - TEST_DO(assertNotOptimized("reduce(v02_x3*v07_x3_a,sum)")); - TEST_DO(assertNotOptimized("reduce(v07_x3_a*v03_x3,sum)")); -} - TEST("require that dot product works with tensor function") { TEST_DO(assertOptimized("reduce(v05_x5*v06_x5,sum)")); TEST_DO(assertOptimized("reduce(v05_x5*v06_x5,sum,x)")); @@ -162,18 +134,11 @@ TEST("require that dot product with compatible dimensions is optimized") { TEST_DO(assertOptimized("reduce(v01_x1*v01_x1,sum)")); TEST_DO(assertOptimized("reduce(v02_x3*v03_x3,sum)")); TEST_DO(assertOptimized("reduce(v05_x5*v06_x5,sum)")); - - TEST_DO(assertOptimized("reduce(v02_x3*v06_x5,sum)")); - TEST_DO(assertOptimized("reduce(v05_x5*v03_x3,sum)")); - TEST_DO(assertOptimized("reduce(v08_x3_u*v05_x5,sum)")); - TEST_DO(assertOptimized("reduce(v05_x5*v08_x3_u,sum)")); } TEST("require that dot product with incompatible dimensions is NOT optimized") { TEST_DO(assertNotOptimized("reduce(v02_x3*v04_y3,sum)")); TEST_DO(assertNotOptimized("reduce(v04_y3*v02_x3,sum)")); - TEST_DO(assertNotOptimized("reduce(v08_x3_u*v04_y3,sum)")); - TEST_DO(assertNotOptimized("reduce(v04_y3*v08_x3_u,sum)")); TEST_DO(assertNotOptimized("reduce(v02_x3*m01_x3y3,sum)")); TEST_DO(assertNotOptimized("reduce(m01_x3y3*v02_x3,sum)")); } @@ -188,11 +153,8 @@ TEST("require that expressions similar to dot product are not optimized") { } TEST("require that multi-dimensional dot product can be optimized") { - TEST_DO(assertOptimized("reduce(m01_x3y3*m02_x2y3,sum)")); - TEST_DO(assertOptimized("reduce(m02_x2y3*m01_x3y3,sum)")); - TEST_DO(assertOptimized("reduce(m01_x3y3*m04_xuy3,sum)")); - TEST_DO(assertOptimized("reduce(m04_xuy3*m01_x3y3,sum)")); - TEST_DO(assertOptimized("reduce(m04_xuy3*m04_xuy3,sum)")); + TEST_DO(assertOptimized("reduce(m01_x3y3*m02_x3y3,sum)")); + TEST_DO(assertOptimized("reduce(m02_x3y3*m01_x3y3,sum)")); } TEST("require that result must be double to trigger optimization") { @@ -201,14 +163,6 @@ TEST("require that result must be double to trigger optimization") { TEST_DO(assertNotOptimized("reduce(m01_x3y3*m01_x3y3,sum,y)")); } -TEST("require that additional dimensions must have matching size") { - TEST_DO(assertOptimized("reduce(m01_x3y3*m01_x3y3,sum)")); - TEST_DO(assertNotOptimized("reduce(m01_x3y3*m03_x3y2,sum)")); - TEST_DO(assertNotOptimized("reduce(m03_x3y2*m01_x3y3,sum)")); - TEST_DO(assertNotOptimized("reduce(m01_x3y3*m05_x3yu,sum)")); - TEST_DO(assertNotOptimized("reduce(m05_x3yu*m01_x3y3,sum)")); -} - void verify_compatible(const vespalib::string &a, const vespalib::string &b) { auto a_type = ValueType::from_spec(a); auto b_type = ValueType::from_spec(b); @@ -231,13 +185,9 @@ TEST("require that type compatibility test is appropriate") { TEST_DO(verify_compatible("tensor(x[5])", "tensor(x[5])")); TEST_DO(verify_not_compatible("tensor(x[5])", "tensor(y[5])")); TEST_DO(verify_compatible("tensor(x[5])", "tensor(x[3])")); - TEST_DO(verify_compatible("tensor(x[])", "tensor(x[3])")); TEST_DO(verify_compatible("tensor(x[3],y[7],z[9])", "tensor(x[5],y[7],z[9])")); - TEST_DO(verify_compatible("tensor(x[3],y[7],z[9])", "tensor(x[],y[7],z[9])")); TEST_DO(verify_not_compatible("tensor(x[5],y[7],z[9])", "tensor(x[5],y[5],z[9])")); - TEST_DO(verify_not_compatible("tensor(x[5],y[],z[9])", "tensor(x[5],y[7],z[9])")); TEST_DO(verify_not_compatible("tensor(x[5],y[7],z[9])", "tensor(x[5],y[7],z[5])")); - TEST_DO(verify_not_compatible("tensor(x[5],y[7],z[])", "tensor(x[5],y[7],z[9])")); } //----------------------------------------------------------------------------- |