From ab5fbb57a43485ea908fb98dea13f9d5cf3d13f6 Mon Sep 17 00:00:00 2001 From: HÃ¥vard Pettersen Date: Tue, 28 Jan 2020 14:42:19 +0000 Subject: better coverage of cell type and parameter ordering --- .../dense_xw_product_function_test.cpp | 116 ++++++++++++--------- 1 file changed, 69 insertions(+), 47 deletions(-) diff --git a/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp b/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp index 426281686d7..36609c04219 100644 --- a/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp +++ b/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp @@ -1,8 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -#include -LOG_SETUP("dense_dot_product_function_test"); - #include #include #include @@ -26,6 +23,12 @@ using namespace vespalib::eval::tensor_function; const TensorEngine &prod_engine = DefaultTensorEngine::ref(); +struct First { + bool value; + explicit First(bool value_in) : value(value_in) {} + operator bool() const { return value; } +}; + struct MyVecSeq : Sequence { double operator[](size_t i) const override { return (3.0 + i) * 7.0; } }; @@ -34,30 +37,44 @@ struct MyMatSeq : Sequence { double operator[](size_t i) const override { return (5.0 + i) * 43.0; } }; +void add_vector(EvalFixture::ParamRepo &repo, const char *d1, size_t s1) { + auto name = make_string("%s%zu", d1, s1); + auto layout = Layout({{d1, s1}}); + repo.add(name, spec(layout, MyVecSeq())); + repo.add(name + "f", spec(float_cells(layout), MyVecSeq())); +} + +void add_matrix(EvalFixture::ParamRepo &repo, const char *d1, size_t s1, const char *d2, size_t s2) { + auto name = make_string("%s%zu%s%zu", d1, s1, d2, s2); + auto layout = Layout({{d1, s1}, {d2, s2}}); + repo.add(name, spec(layout, MyMatSeq())); + repo.add(name + "f", spec(float_cells(layout), MyMatSeq())); +} + EvalFixture::ParamRepo make_params() { - return EvalFixture::ParamRepo() - .add("y1", spec({y(1)}, MyVecSeq())) - .add("y3", spec({y(3)}, MyVecSeq())) - .add("y3f", spec(float_cells({y(3)}), MyVecSeq())) - .add("y5", spec({y(5)}, MyVecSeq())) - .add("y16", spec({y(16)}, MyVecSeq())) - .add("x1y1", spec({x(1),y(1)}, MyMatSeq())) - .add("y1z1", spec({y(1),z(1)}, MyMatSeq())) - .add("x2y3", spec({x(2),y(3)}, MyMatSeq())) - .add("x2y3f", spec(float_cells({x(2),y(3)}), MyMatSeq())) - .add("y3z2f", spec(float_cells({y(3),z(2)}), MyMatSeq())) - .add("x2z3", spec({x(2),z(3)}, MyMatSeq())) - .add("y3z2", spec({y(3),z(2)}, MyMatSeq())) - .add("x8y5", spec({x(8),y(5)}, MyMatSeq())) - .add("y5z8", spec({y(5),z(8)}, MyMatSeq())) - .add("x5y16", spec({x(5),y(16)}, MyMatSeq())) - .add("y16z5", spec({y(16),z(5)}, MyMatSeq())); + EvalFixture::ParamRepo repo; + add_vector(repo, "y", 1); + add_vector(repo, "y", 3); + add_vector(repo, "y", 5); + add_vector(repo, "y", 16); + add_matrix(repo, "x", 1, "y", 1); + add_matrix(repo, "y", 1, "z", 1); + add_matrix(repo, "x", 2, "y", 3); + add_matrix(repo, "y", 3, "z", 2); + add_matrix(repo, "x", 2, "z", 3); + add_matrix(repo, "x", 8, "y", 5); + add_matrix(repo, "y", 5, "z", 8); + add_matrix(repo, "x", 5, "y", 16); + add_matrix(repo, "y", 16, "z", 5); + return repo; } EvalFixture::ParamRepo param_repo = make_params(); void verify_optimized(const vespalib::string &expr, size_t vec_size, size_t res_size, bool happy) { + EvalFixture slow_fixture(prod_engine, expr, param_repo, false); EvalFixture fixture(prod_engine, expr, param_repo, true); EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo)); + EXPECT_EQUAL(fixture.result(), slow_fixture.result()); auto info = fixture.find_all(); ASSERT_EQUAL(info.size(), 1u); EXPECT_TRUE(info[0]->result_is_mutable()); @@ -66,33 +83,54 @@ void verify_optimized(const vespalib::string &expr, size_t vec_size, size_t res_ EXPECT_EQUAL(info[0]->matrixHasCommonDimensionInnermost(), happy); } +vespalib::string make_expr(const vespalib::string &a, const vespalib::string &b, const vespalib::string &common, + bool float_a, bool float_b) +{ + return make_string("reduce(%s%s*%s%s,sum,%s)", a.c_str(), float_a ? "f" : "", b.c_str(), float_b ? "f" : "", common.c_str()); +} + +void verify_optimized_multi(const vespalib::string &a, const vespalib::string &b, const vespalib::string &common, + size_t vec_size, size_t res_size, bool happy, First first = First(true)) +{ + for (bool float_a: {false, true}) { + for (bool float_b: {false, true}) { + auto expr = make_expr(a, b, common, float_a, float_b); + TEST_STATE(expr.c_str()); + TEST_DO(verify_optimized(expr, vec_size, res_size, happy)); + } + } + if (first) { + TEST_DO(verify_optimized_multi(b, a, common, vec_size, res_size, happy, First(false))); + } +} + void verify_not_optimized(const vespalib::string &expr) { + EvalFixture slow_fixture(prod_engine, expr, param_repo, false); EvalFixture fixture(prod_engine, expr, param_repo, true); EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo)); + EXPECT_EQUAL(fixture.result(), slow_fixture.result()); auto info = fixture.find_all(); EXPECT_TRUE(info.empty()); } TEST("require that xw product gives same results as reference join/reduce") { // 1 -> 1 happy/unhappy - TEST_DO(verify_optimized("reduce(y1*x1y1,sum,y)", 1, 1, true)); - TEST_DO(verify_optimized("reduce(y1*y1z1,sum,y)", 1, 1, false)); + TEST_DO(verify_optimized_multi("y1", "x1y1", "y", 1, 1, true)); + TEST_DO(verify_optimized_multi("y1", "y1z1", "y", 1, 1, false)); // 3 -> 2 happy/unhappy - TEST_DO(verify_optimized("reduce(y3*x2y3,sum,y)", 3, 2, true)); - TEST_DO(verify_optimized("reduce(y3*y3z2,sum,y)", 3, 2, false)); + TEST_DO(verify_optimized_multi("y3", "x2y3", "y", 3, 2, true)); + TEST_DO(verify_optimized_multi("y3", "y3z2", "y", 3, 2, false)); // 5 -> 8 happy/unhappy - TEST_DO(verify_optimized("reduce(y5*x8y5,sum,y)", 5, 8, true)); - TEST_DO(verify_optimized("reduce(y5*y5z8,sum,y)", 5, 8, false)); + TEST_DO(verify_optimized_multi("y5", "x8y5", "y", 5, 8, true)); + TEST_DO(verify_optimized_multi("y5", "y5z8", "y", 5, 8, false)); // 16 -> 5 happy/unhappy - TEST_DO(verify_optimized("reduce(y16*x5y16,sum,y)", 16, 5, true)); - TEST_DO(verify_optimized("reduce(y16*y16z5,sum,y)", 16, 5, false)); + TEST_DO(verify_optimized_multi("y16", "x5y16", "y", 16, 5, true)); + TEST_DO(verify_optimized_multi("y16", "y16z5", "y", 16, 5, false)); } TEST("require that various variants of xw product can be optimized") { - TEST_DO(verify_optimized("reduce(y3*x2y3,sum,y)", 3, 2, true)); - TEST_DO(verify_optimized("reduce(x2y3*y3,sum,y)", 3, 2, true)); TEST_DO(verify_optimized("reduce(join(y3,x2y3,f(x,y)(x*y)),sum,y)", 3, 2, true)); - TEST_DO(verify_optimized("reduce(join(x2y3,y3,f(x,y)(x*y)),sum,y)", 3, 2, true)); + TEST_DO(verify_optimized("reduce(join(y3,x2y3,f(x,y)(y*x)),sum,y)", 3, 2, true)); } TEST("require that expressions similar to xw product are not optimized") { @@ -100,13 +138,9 @@ TEST("require that expressions similar to xw product are not optimized") { TEST_DO(verify_not_optimized("reduce(y3*x2y3,prod,y)")); TEST_DO(verify_not_optimized("reduce(y3*x2y3,sum)")); TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(x+y)),sum,y)")); - // TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(y*x)),sum,y)")); TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(x*x)),sum,y)")); TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(y*y)),sum,y)")); TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(y*x*1)),sum,y)")); -} - -TEST("require that xw products with incompatible dimensions are not optimized") { TEST_DO(verify_not_optimized("reduce(y3*x2z3,sum,y)")); TEST_DO(verify_not_optimized("reduce(y3*x2z3,sum,z)")); } @@ -119,16 +153,4 @@ TEST("require that xw product can be debug dumped") { fprintf(stderr, "%s\n", info[0]->as_string().c_str()); } -TEST("require that optimization works for float cells") { - TEST_DO(verify_optimized("reduce(y3f*x2y3,sum,y)", 3, 2, true)); - TEST_DO(verify_optimized("reduce(y3*x2y3f,sum,y)", 3, 2, true)); - TEST_DO(verify_optimized("reduce(y3f*x2y3f,sum,y)", 3, 2, true)); -} - -TEST("require that optimization works for float cells with inconvenient dimension nesting") { - TEST_DO(verify_optimized("reduce(y3f*y3z2,sum,y)", 3, 2, false)); - TEST_DO(verify_optimized("reduce(y3*y3z2f,sum,y)", 3, 2, false)); - TEST_DO(verify_optimized("reduce(y3f*y3z2f,sum,y)", 3, 2, false)); -} - TEST_MAIN() { TEST_RUN_ALL(); } -- cgit v1.2.3