From 735051eeef7ca1296333f58e29c3887622385b2e Mon Sep 17 00:00:00 2001 From: HÃ¥vard Pettersen Date: Thu, 8 Feb 2018 15:32:31 +0000 Subject: extend xw product test --- .../dense_xw_product_function_test.cpp | 159 +++++++++++---------- 1 file changed, 86 insertions(+), 73 deletions(-) (limited to 'eval/src/tests') diff --git a/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp b/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp index ffd0f17be75..01abad343ae 100644 --- a/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp +++ b/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp @@ -13,101 +13,114 @@ LOG_SETUP("dense_dot_product_function_test"); #include #include #include +#include +#include #include #include using namespace vespalib; using namespace vespalib::eval; +using namespace vespalib::eval::test; using namespace vespalib::tensor; using namespace vespalib::eval::tensor_function; -const TensorEngine &ref_engine = SimpleTensorEngine::ref(); const TensorEngine &prod_engine = DefaultTensorEngine::ref(); -void verify_equal(const Value &expect, const Value &value) { - const eval::Tensor *tensor = value.as_tensor(); - ASSERT_TRUE(tensor != nullptr); - const eval::Tensor *expect_tensor = expect.as_tensor(); - ASSERT_TRUE(expect_tensor != nullptr); - auto expect_spec = expect_tensor->engine().to_spec(expect); - auto value_spec = tensor->engine().to_spec(value); - EXPECT_EQUAL(expect_spec, value_spec); +struct MyVecSeq : Sequence { + double operator[](size_t i) const override { return (3.0 + i) * 7.0; } +}; + +struct MyMatSeq : Sequence { + double operator[](size_t i) const override { return (5.0 + i) * 43.0; } +}; + +EvalFixture::ParamRepo make_params() { + return EvalFixture::ParamRepo() + .add("y1", spec({y(1)}, MyVecSeq())) + .add("y3", spec({y(3)}, MyVecSeq())) + .add("y5", spec({y(5)}, MyVecSeq())) + .add("y16", spec({y(16)}, MyVecSeq())) + .add("x1y1", spec({x(1),y(1)}, MyMatSeq())) + .add("y1z1", spec({y(1),z(1)}, MyMatSeq())) + .add("x2y3", spec({x(2),y(3)}, MyMatSeq())) + .add("x2z3", spec({x(2),z(3)}, MyMatSeq())) + .add("y3z2", spec({y(3),z(2)}, MyMatSeq())) + .add("x8y5", spec({x(8),y(5)}, MyMatSeq())) + .add("y5z8", spec({y(5),z(8)}, MyMatSeq())) + .add("x5y16", spec({x(5),y(16)}, MyMatSeq())) + .add("y16z5", spec({y(16),z(5)}, MyMatSeq())) + .add("a_y3", spec({y(3)}, MyVecSeq()), "any") + .add("y3_u", spec({y(3)}, MyVecSeq()), "tensor(y[])") + .add("a_x2y3", spec({x(2),y(3)}, MyMatSeq()), "any") + .add("x2_uy3", spec({x(2),y(3)}, MyMatSeq()), "tensor(x[],y[3])") + .add("x2y3_u", spec({x(2),y(3)}, MyMatSeq()), "tensor(x[2],y[])"); } - -SimpleObjectParams wrap(std::vector params) { - return SimpleObjectParams(params); +EvalFixture::ParamRepo param_repo = make_params(); + +void verify_optimized(const vespalib::string &expr, size_t vec_size, size_t res_size, bool happy) { + EvalFixture fixture(prod_engine, expr, param_repo, true); + EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo)); + auto info = fixture.find_all(); + ASSERT_EQUAL(info.size(), 1u); + EXPECT_EQUAL(info[0]->vectorSize(), vec_size); + EXPECT_EQUAL(info[0]->resultSize(), res_size); + EXPECT_EQUAL(info[0]->matrixHasCommonDimensionInnermost(), happy); } -void verify_result(const TensorSpec &v, const TensorSpec &m, bool happy) { - Stash stash; - Value::UP ref_vec = ref_engine.from_spec(v); - Value::UP ref_mat = ref_engine.from_spec(m); - const Value &joined = ref_engine.join(*ref_vec, *ref_mat, operation::Mul::f, stash); - const Value &expect = ref_engine.reduce(joined, Aggr::SUM, {"x"}, stash); - - Value::UP prod_vec = prod_engine.from_spec(v); - Value::UP prod_mat = prod_engine.from_spec(m); - - Inject vec_first(prod_vec->type(), 0); - Inject mat_last(prod_mat->type(), 1); - - DenseXWProductFunction fun1(expect.type(), vec_first, mat_last, - prod_vec->type().dimensions()[0].size, - expect.type().dimensions()[0].size, - happy); - InterpretedFunction ifun1(prod_engine, fun1); - InterpretedFunction::Context ictx1(ifun1); - const Value &actual1 = ifun1.eval(ictx1, wrap({*prod_vec, *prod_mat})); - TEST_DO(verify_equal(expect, actual1)); - - Inject vec_last(prod_vec->type(), 1); - Inject mat_first(prod_mat->type(), 0); - - DenseXWProductFunction fun2(expect.type(), vec_last, mat_first, - prod_vec->type().dimensions()[0].size, - expect.type().dimensions()[0].size, - happy); - InterpretedFunction ifun2(prod_engine, fun2); - InterpretedFunction::Context ictx2(ifun2); - const Value &actual2 = ifun2.eval(ictx2, wrap({*prod_mat, *prod_vec})); - TEST_DO(verify_equal(expect, actual2)); +void verify_not_optimized(const vespalib::string &expr) { + EvalFixture fixture(prod_engine, expr, param_repo, true); + EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo)); + auto info = fixture.find_all(); + EXPECT_TRUE(info.empty()); } -TensorSpec make_vector(const vespalib::string &name, size_t sz) { - TensorSpec ret(make_string("tensor(%s[%zu])", name.c_str(), sz)); - for (size_t i = 0; i < sz; ++i) { - ret.add({{name, i}}, (1.0 + i) * 16.0); - } - return ret; +TEST("require that xw product gives same results as reference join/reduce") { + // 1 -> 1 happy/unhappy + TEST_DO(verify_optimized("reduce(y1*x1y1,sum,y)", 1, 1, true)); + TEST_DO(verify_optimized("reduce(y1*y1z1,sum,y)", 1, 1, false)); + // 3 -> 2 happy/unhappy + TEST_DO(verify_optimized("reduce(y3*x2y3,sum,y)", 3, 2, true)); + TEST_DO(verify_optimized("reduce(y3*y3z2,sum,y)", 3, 2, false)); + // 5 -> 8 happy/unhappy + TEST_DO(verify_optimized("reduce(y5*x8y5,sum,y)", 5, 8, true)); + TEST_DO(verify_optimized("reduce(y5*y5z8,sum,y)", 5, 8, false)); + // 16 -> 5 happy/unhappy + TEST_DO(verify_optimized("reduce(y16*x5y16,sum,y)", 16, 5, true)); + TEST_DO(verify_optimized("reduce(y16*y16z5,sum,y)", 16, 5, false)); } -TensorSpec make_matrix(const vespalib::string &d1name, size_t d1sz, - const vespalib::string &d2name, size_t d2sz) -{ - TensorSpec ret(make_string("tensor(%s[%zu],%s[%zu])", - d1name.c_str(), d1sz, - d2name.c_str(), d2sz)); - for (size_t i = 0; i < d1sz; ++i) { - for (size_t j = 0; j < d2sz; ++j) { - ret.add({{d1name,i},{d2name,j}}, 1.0 + i*7.0 + j*43.0); - } - } - return ret; +TEST("require that xw product is not optimized for abstract types") { + TEST_DO(verify_not_optimized("reduce(a_y3*x2y3,sum)")); + TEST_DO(verify_not_optimized("reduce(y3*a_x2y3,sum)")); + TEST_DO(verify_not_optimized("reduce(y3_u*x2y3,sum)")); + TEST_DO(verify_not_optimized("reduce(y3*x2_uy3,sum)")); + TEST_DO(verify_not_optimized("reduce(y3*x2y3_u,sum)")); } -TEST("require that xw product gives same results as reference join/reduce") { - verify_result(make_vector("x", 1), make_matrix("o", 1, "x", 1), true); - verify_result(make_vector("x", 1), make_matrix("x", 1, "y", 1), false); - - verify_result(make_vector("x", 3), make_matrix("o", 2, "x", 3), true); - verify_result(make_vector("x", 3), make_matrix("x", 3, "y", 2), false); +TEST("require that various variants of xw product can be optimized") { + TEST_DO(verify_optimized("reduce(y3*x2y3,sum,y)", 3, 2, true)); + TEST_DO(verify_optimized("reduce(x2y3*y3,sum,y)", 3, 2, true)); + TEST_DO(verify_optimized("reduce(join(y3,x2y3,f(x,y)(x*y)),sum,y)", 3, 2, true)); + TEST_DO(verify_optimized("reduce(join(x2y3,y3,f(x,y)(x*y)),sum,y)", 3, 2, true)); +} - verify_result(make_vector("x", 5), make_matrix("o", 8, "x", 5), true); - verify_result(make_vector("x", 5), make_matrix("x", 5, "y", 8), false); +TEST("require that expressions similar to xw product are not optimized") { + TEST_DO(verify_not_optimized("reduce(y3*x2y3,sum,x)")); + TEST_DO(verify_not_optimized("reduce(y3*x2y3,prod,y)")); + TEST_DO(verify_not_optimized("reduce(y3*x2y3,sum)")); + TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(x+y)),sum,y)")); + // TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(y*x)),sum,y)")); + TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(x*x)),sum,y)")); + TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(y*y)),sum,y)")); + TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(y*x*1)),sum,y)")); +} - verify_result(make_vector("x", 16), make_matrix("o", 5, "x", 16), true); - verify_result(make_vector("x", 16), make_matrix("x", 16, "y", 5), false); +TEST("require that xw products with incompatible dimensions are not optimized") { + TEST_DO(verify_not_optimized("reduce(y3*x1y1,sum,y)")); + TEST_DO(verify_not_optimized("reduce(y3*x8y5,sum,y)")); + TEST_DO(verify_not_optimized("reduce(y3*x2z3,sum,y)")); + TEST_DO(verify_not_optimized("reduce(y3*x2z3,sum,z)")); } TEST_MAIN() { TEST_RUN_ALL(); } -- cgit v1.2.3