author     Håvard Pettersen <havardpe@oath.com>    2018-02-08 15:32:31 +0000
committer  Håvard Pettersen <havardpe@oath.com>    2018-02-08 15:46:56 +0000
commit     735051eeef7ca1296333f58e29c3887622385b2e (patch)
tree       988148bce8a096dd41c12bec8db9b45ac019cc3f /eval
parent     3e79340de4405f58ead08845bc07fbd50e9253ae (diff)
extend xw product test
Diffstat (limited to 'eval')
-rw-r--r--  eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp  159
1 file changed, 86 insertions(+), 73 deletions(-)
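
For reference while reading the patch: the xw product exercised by these tests is the vector-matrix product obtained from reduce(vec*mat,sum,y), i.e. result[x] = sum over y of vec[y] * mat[x][y]. A minimal sketch of that reference semantics in plain C++ (illustrative only, not part of the patch; the function name and the flat row-major layout are assumptions):

#include <cstddef>
#include <vector>

// reduce(join(vec,mat,mul),sum,y): multiply the vector into each matrix row
// and sum over the common dimension y.
// vec: tensor(y[vec_size]); mat: tensor(x[res_size],y[vec_size]) with y
// innermost (the "happy" layout in the tests below), stored as a flat array.
std::vector<double> xw_product_ref(const std::vector<double> &vec,
                                   const std::vector<double> &mat,
                                   size_t vec_size, size_t res_size)
{
    std::vector<double> result(res_size, 0.0);
    for (size_t x = 0; x < res_size; ++x) {
        for (size_t y = 0; y < vec_size; ++y) {
            result[x] += vec[y] * mat[x * vec_size + y];
        }
    }
    return result;
}
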
diff --git a/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp b/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp
index ffd0f17be75..01abad343ae 100644
--- a/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp
+++ b/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp
@@ -13,101 +13,114 @@ LOG_SETUP("dense_dot_product_function_test");
#include <vespa/eval/tensor/dense/dense_tensor.h>
#include <vespa/eval/tensor/dense/dense_tensor_builder.h>
#include <vespa/eval/tensor/dense/dense_tensor_view.h>
+#include <vespa/eval/eval/test/tensor_model.hpp>
+#include <vespa/eval/eval/test/eval_fixture.h>
#include <vespa/vespalib/util/stringfmt.h>
#include <vespa/vespalib/util/stash.h>
using namespace vespalib;
using namespace vespalib::eval;
+using namespace vespalib::eval::test;
using namespace vespalib::tensor;
using namespace vespalib::eval::tensor_function;
-const TensorEngine &ref_engine = SimpleTensorEngine::ref();
const TensorEngine &prod_engine = DefaultTensorEngine::ref();
-void verify_equal(const Value &expect, const Value &value) {
- const eval::Tensor *tensor = value.as_tensor();
- ASSERT_TRUE(tensor != nullptr);
- const eval::Tensor *expect_tensor = expect.as_tensor();
- ASSERT_TRUE(expect_tensor != nullptr);
- auto expect_spec = expect_tensor->engine().to_spec(expect);
- auto value_spec = tensor->engine().to_spec(value);
- EXPECT_EQUAL(expect_spec, value_spec);
+struct MyVecSeq : Sequence {
+ double operator[](size_t i) const override { return (3.0 + i) * 7.0; }
+};
+
+struct MyMatSeq : Sequence {
+ double operator[](size_t i) const override { return (5.0 + i) * 43.0; }
+};
+
+EvalFixture::ParamRepo make_params() {
+ return EvalFixture::ParamRepo()
+ .add("y1", spec({y(1)}, MyVecSeq()))
+ .add("y3", spec({y(3)}, MyVecSeq()))
+ .add("y5", spec({y(5)}, MyVecSeq()))
+ .add("y16", spec({y(16)}, MyVecSeq()))
+ .add("x1y1", spec({x(1),y(1)}, MyMatSeq()))
+ .add("y1z1", spec({y(1),z(1)}, MyMatSeq()))
+ .add("x2y3", spec({x(2),y(3)}, MyMatSeq()))
+ .add("x2z3", spec({x(2),z(3)}, MyMatSeq()))
+ .add("y3z2", spec({y(3),z(2)}, MyMatSeq()))
+ .add("x8y5", spec({x(8),y(5)}, MyMatSeq()))
+ .add("y5z8", spec({y(5),z(8)}, MyMatSeq()))
+ .add("x5y16", spec({x(5),y(16)}, MyMatSeq()))
+ .add("y16z5", spec({y(16),z(5)}, MyMatSeq()))
+ .add("a_y3", spec({y(3)}, MyVecSeq()), "any")
+ .add("y3_u", spec({y(3)}, MyVecSeq()), "tensor(y[])")
+ .add("a_x2y3", spec({x(2),y(3)}, MyMatSeq()), "any")
+ .add("x2_uy3", spec({x(2),y(3)}, MyMatSeq()), "tensor(x[],y[3])")
+ .add("x2y3_u", spec({x(2),y(3)}, MyMatSeq()), "tensor(x[2],y[])");
}
-
-SimpleObjectParams wrap(std::vector<eval::Value::CREF> params) {
- return SimpleObjectParams(params);
+EvalFixture::ParamRepo param_repo = make_params();
+
+void verify_optimized(const vespalib::string &expr, size_t vec_size, size_t res_size, bool happy) {
+ EvalFixture fixture(prod_engine, expr, param_repo, true);
+ EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo));
+ auto info = fixture.find_all<DenseXWProductFunction>();
+ ASSERT_EQUAL(info.size(), 1u);
+ EXPECT_EQUAL(info[0]->vectorSize(), vec_size);
+ EXPECT_EQUAL(info[0]->resultSize(), res_size);
+ EXPECT_EQUAL(info[0]->matrixHasCommonDimensionInnermost(), happy);
}
-void verify_result(const TensorSpec &v, const TensorSpec &m, bool happy) {
- Stash stash;
- Value::UP ref_vec = ref_engine.from_spec(v);
- Value::UP ref_mat = ref_engine.from_spec(m);
- const Value &joined = ref_engine.join(*ref_vec, *ref_mat, operation::Mul::f, stash);
- const Value &expect = ref_engine.reduce(joined, Aggr::SUM, {"x"}, stash);
-
- Value::UP prod_vec = prod_engine.from_spec(v);
- Value::UP prod_mat = prod_engine.from_spec(m);
-
- Inject vec_first(prod_vec->type(), 0);
- Inject mat_last(prod_mat->type(), 1);
-
- DenseXWProductFunction fun1(expect.type(), vec_first, mat_last,
- prod_vec->type().dimensions()[0].size,
- expect.type().dimensions()[0].size,
- happy);
- InterpretedFunction ifun1(prod_engine, fun1);
- InterpretedFunction::Context ictx1(ifun1);
- const Value &actual1 = ifun1.eval(ictx1, wrap({*prod_vec, *prod_mat}));
- TEST_DO(verify_equal(expect, actual1));
-
- Inject vec_last(prod_vec->type(), 1);
- Inject mat_first(prod_mat->type(), 0);
-
- DenseXWProductFunction fun2(expect.type(), vec_last, mat_first,
- prod_vec->type().dimensions()[0].size,
- expect.type().dimensions()[0].size,
- happy);
- InterpretedFunction ifun2(prod_engine, fun2);
- InterpretedFunction::Context ictx2(ifun2);
- const Value &actual2 = ifun2.eval(ictx2, wrap({*prod_mat, *prod_vec}));
- TEST_DO(verify_equal(expect, actual2));
+void verify_not_optimized(const vespalib::string &expr) {
+ EvalFixture fixture(prod_engine, expr, param_repo, true);
+ EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo));
+ auto info = fixture.find_all<DenseXWProductFunction>();
+ EXPECT_TRUE(info.empty());
}
-TensorSpec make_vector(const vespalib::string &name, size_t sz) {
- TensorSpec ret(make_string("tensor(%s[%zu])", name.c_str(), sz));
- for (size_t i = 0; i < sz; ++i) {
- ret.add({{name, i}}, (1.0 + i) * 16.0);
- }
- return ret;
+TEST("require that xw product gives same results as reference join/reduce") {
+ // 1 -> 1 happy/unhappy
+ TEST_DO(verify_optimized("reduce(y1*x1y1,sum,y)", 1, 1, true));
+ TEST_DO(verify_optimized("reduce(y1*y1z1,sum,y)", 1, 1, false));
+ // 3 -> 2 happy/unhappy
+ TEST_DO(verify_optimized("reduce(y3*x2y3,sum,y)", 3, 2, true));
+ TEST_DO(verify_optimized("reduce(y3*y3z2,sum,y)", 3, 2, false));
+ // 5 -> 8 happy/unhappy
+ TEST_DO(verify_optimized("reduce(y5*x8y5,sum,y)", 5, 8, true));
+ TEST_DO(verify_optimized("reduce(y5*y5z8,sum,y)", 5, 8, false));
+ // 16 -> 5 happy/unhappy
+ TEST_DO(verify_optimized("reduce(y16*x5y16,sum,y)", 16, 5, true));
+ TEST_DO(verify_optimized("reduce(y16*y16z5,sum,y)", 16, 5, false));
}
-TensorSpec make_matrix(const vespalib::string &d1name, size_t d1sz,
- const vespalib::string &d2name, size_t d2sz)
-{
- TensorSpec ret(make_string("tensor(%s[%zu],%s[%zu])",
- d1name.c_str(), d1sz,
- d2name.c_str(), d2sz));
- for (size_t i = 0; i < d1sz; ++i) {
- for (size_t j = 0; j < d2sz; ++j) {
- ret.add({{d1name,i},{d2name,j}}, 1.0 + i*7.0 + j*43.0);
- }
- }
- return ret;
+TEST("require that xw product is not optimized for abstract types") {
+ TEST_DO(verify_not_optimized("reduce(a_y3*x2y3,sum)"));
+ TEST_DO(verify_not_optimized("reduce(y3*a_x2y3,sum)"));
+ TEST_DO(verify_not_optimized("reduce(y3_u*x2y3,sum)"));
+ TEST_DO(verify_not_optimized("reduce(y3*x2_uy3,sum)"));
+ TEST_DO(verify_not_optimized("reduce(y3*x2y3_u,sum)"));
}
-TEST("require that xw product gives same results as reference join/reduce") {
- verify_result(make_vector("x", 1), make_matrix("o", 1, "x", 1), true);
- verify_result(make_vector("x", 1), make_matrix("x", 1, "y", 1), false);
-
- verify_result(make_vector("x", 3), make_matrix("o", 2, "x", 3), true);
- verify_result(make_vector("x", 3), make_matrix("x", 3, "y", 2), false);
+TEST("require that various variants of xw product can be optimized") {
+ TEST_DO(verify_optimized("reduce(y3*x2y3,sum,y)", 3, 2, true));
+ TEST_DO(verify_optimized("reduce(x2y3*y3,sum,y)", 3, 2, true));
+ TEST_DO(verify_optimized("reduce(join(y3,x2y3,f(x,y)(x*y)),sum,y)", 3, 2, true));
+ TEST_DO(verify_optimized("reduce(join(x2y3,y3,f(x,y)(x*y)),sum,y)", 3, 2, true));
+}
- verify_result(make_vector("x", 5), make_matrix("o", 8, "x", 5), true);
- verify_result(make_vector("x", 5), make_matrix("x", 5, "y", 8), false);
+TEST("require that expressions similar to xw product are not optimized") {
+ TEST_DO(verify_not_optimized("reduce(y3*x2y3,sum,x)"));
+ TEST_DO(verify_not_optimized("reduce(y3*x2y3,prod,y)"));
+ TEST_DO(verify_not_optimized("reduce(y3*x2y3,sum)"));
+ TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(x+y)),sum,y)"));
+ // TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(y*x)),sum,y)"));
+ TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(x*x)),sum,y)"));
+ TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(y*y)),sum,y)"));
+ TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(y*x*1)),sum,y)"));
+}
- verify_result(make_vector("x", 16), make_matrix("o", 5, "x", 16), true);
- verify_result(make_vector("x", 16), make_matrix("x", 16, "y", 5), false);
+TEST("require that xw products with incompatible dimensions are not optimized") {
+ TEST_DO(verify_not_optimized("reduce(y3*x1y1,sum,y)"));
+ TEST_DO(verify_not_optimized("reduce(y3*x8y5,sum,y)"));
+ TEST_DO(verify_not_optimized("reduce(y3*x2z3,sum,y)"));
+ TEST_DO(verify_not_optimized("reduce(y3*x2z3,sum,z)"));
}
TEST_MAIN() { TEST_RUN_ALL(); }
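
The happy flag checked by verify_optimized (matrixHasCommonDimensionInnermost) distinguishes the two dense layouts the optimization handles: a matrix like x2y3 keeps the common dimension y innermost, while y3z2 keeps it outermost. A minimal sketch of the outermost ("unhappy") variant, assuming the same flat row-major cell order as the sketch above (illustrative only, not the actual DenseXWProductFunction kernel):

#include <cstddef>
#include <vector>

// Same product as xw_product_ref, but mat has type tensor(y[vec_size],z[res_size]):
// the common dimension y is outermost, so the reads for a fixed result cell z
// are strided by res_size instead of contiguous.
std::vector<double> xw_product_unhappy_ref(const std::vector<double> &vec,
                                           const std::vector<double> &mat,
                                           size_t vec_size, size_t res_size)
{
    std::vector<double> result(res_size, 0.0);
    for (size_t z = 0; z < res_size; ++z) {
        for (size_t y = 0; y < vec_size; ++y) {
            result[z] += vec[y] * mat[y * res_size + z];
        }
    }
    return result;
}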