summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2020-01-28 14:42:19 +0000
committerHåvard Pettersen <havardpe@oath.com>2020-01-28 14:48:02 +0000
commitab5fbb57a43485ea908fb98dea13f9d5cf3d13f6 (patch)
treec2dabdd323bd5761195e6ca3e206ca9312985f3e /eval
parentd0e9532d3178ed0c4631d32e89ccf9faeef084a1 (diff)
better coverage of cell type and parameter ordering
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp116
1 files changed, 69 insertions, 47 deletions
diff --git a/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp b/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp
index 426281686d7..36609c04219 100644
--- a/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp
+++ b/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp
@@ -1,8 +1,5 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-#include <vespa/log/log.h>
-LOG_SETUP("dense_dot_product_function_test");
-
#include <vespa/vespalib/testkit/test_kit.h>
#include <vespa/eval/eval/tensor_function.h>
#include <vespa/eval/eval/operation.h>
@@ -26,6 +23,12 @@ using namespace vespalib::eval::tensor_function;
const TensorEngine &prod_engine = DefaultTensorEngine::ref();
+struct First {
+ bool value;
+ explicit First(bool value_in) : value(value_in) {}
+ operator bool() const { return value; }
+};
+
struct MyVecSeq : Sequence {
double operator[](size_t i) const override { return (3.0 + i) * 7.0; }
};
@@ -34,30 +37,44 @@ struct MyMatSeq : Sequence {
double operator[](size_t i) const override { return (5.0 + i) * 43.0; }
};
+void add_vector(EvalFixture::ParamRepo &repo, const char *d1, size_t s1) {
+ auto name = make_string("%s%zu", d1, s1);
+ auto layout = Layout({{d1, s1}});
+ repo.add(name, spec(layout, MyVecSeq()));
+ repo.add(name + "f", spec(float_cells(layout), MyVecSeq()));
+}
+
+void add_matrix(EvalFixture::ParamRepo &repo, const char *d1, size_t s1, const char *d2, size_t s2) {
+ auto name = make_string("%s%zu%s%zu", d1, s1, d2, s2);
+ auto layout = Layout({{d1, s1}, {d2, s2}});
+ repo.add(name, spec(layout, MyMatSeq()));
+ repo.add(name + "f", spec(float_cells(layout), MyMatSeq()));
+}
+
EvalFixture::ParamRepo make_params() {
- return EvalFixture::ParamRepo()
- .add("y1", spec({y(1)}, MyVecSeq()))
- .add("y3", spec({y(3)}, MyVecSeq()))
- .add("y3f", spec(float_cells({y(3)}), MyVecSeq()))
- .add("y5", spec({y(5)}, MyVecSeq()))
- .add("y16", spec({y(16)}, MyVecSeq()))
- .add("x1y1", spec({x(1),y(1)}, MyMatSeq()))
- .add("y1z1", spec({y(1),z(1)}, MyMatSeq()))
- .add("x2y3", spec({x(2),y(3)}, MyMatSeq()))
- .add("x2y3f", spec(float_cells({x(2),y(3)}), MyMatSeq()))
- .add("y3z2f", spec(float_cells({y(3),z(2)}), MyMatSeq()))
- .add("x2z3", spec({x(2),z(3)}, MyMatSeq()))
- .add("y3z2", spec({y(3),z(2)}, MyMatSeq()))
- .add("x8y5", spec({x(8),y(5)}, MyMatSeq()))
- .add("y5z8", spec({y(5),z(8)}, MyMatSeq()))
- .add("x5y16", spec({x(5),y(16)}, MyMatSeq()))
- .add("y16z5", spec({y(16),z(5)}, MyMatSeq()));
+ EvalFixture::ParamRepo repo;
+ add_vector(repo, "y", 1);
+ add_vector(repo, "y", 3);
+ add_vector(repo, "y", 5);
+ add_vector(repo, "y", 16);
+ add_matrix(repo, "x", 1, "y", 1);
+ add_matrix(repo, "y", 1, "z", 1);
+ add_matrix(repo, "x", 2, "y", 3);
+ add_matrix(repo, "y", 3, "z", 2);
+ add_matrix(repo, "x", 2, "z", 3);
+ add_matrix(repo, "x", 8, "y", 5);
+ add_matrix(repo, "y", 5, "z", 8);
+ add_matrix(repo, "x", 5, "y", 16);
+ add_matrix(repo, "y", 16, "z", 5);
+ return repo;
}
EvalFixture::ParamRepo param_repo = make_params();
void verify_optimized(const vespalib::string &expr, size_t vec_size, size_t res_size, bool happy) {
+ EvalFixture slow_fixture(prod_engine, expr, param_repo, false);
EvalFixture fixture(prod_engine, expr, param_repo, true);
EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo));
+ EXPECT_EQUAL(fixture.result(), slow_fixture.result());
auto info = fixture.find_all<DenseXWProductFunction>();
ASSERT_EQUAL(info.size(), 1u);
EXPECT_TRUE(info[0]->result_is_mutable());
@@ -66,33 +83,54 @@ void verify_optimized(const vespalib::string &expr, size_t vec_size, size_t res_
EXPECT_EQUAL(info[0]->matrixHasCommonDimensionInnermost(), happy);
}
+vespalib::string make_expr(const vespalib::string &a, const vespalib::string &b, const vespalib::string &common,
+ bool float_a, bool float_b)
+{
+ return make_string("reduce(%s%s*%s%s,sum,%s)", a.c_str(), float_a ? "f" : "", b.c_str(), float_b ? "f" : "", common.c_str());
+}
+
+void verify_optimized_multi(const vespalib::string &a, const vespalib::string &b, const vespalib::string &common,
+ size_t vec_size, size_t res_size, bool happy, First first = First(true))
+{
+ for (bool float_a: {false, true}) {
+ for (bool float_b: {false, true}) {
+ auto expr = make_expr(a, b, common, float_a, float_b);
+ TEST_STATE(expr.c_str());
+ TEST_DO(verify_optimized(expr, vec_size, res_size, happy));
+ }
+ }
+ if (first) {
+ TEST_DO(verify_optimized_multi(b, a, common, vec_size, res_size, happy, First(false)));
+ }
+}
+
void verify_not_optimized(const vespalib::string &expr) {
+ EvalFixture slow_fixture(prod_engine, expr, param_repo, false);
EvalFixture fixture(prod_engine, expr, param_repo, true);
EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo));
+ EXPECT_EQUAL(fixture.result(), slow_fixture.result());
auto info = fixture.find_all<DenseXWProductFunction>();
EXPECT_TRUE(info.empty());
}
TEST("require that xw product gives same results as reference join/reduce") {
// 1 -> 1 happy/unhappy
- TEST_DO(verify_optimized("reduce(y1*x1y1,sum,y)", 1, 1, true));
- TEST_DO(verify_optimized("reduce(y1*y1z1,sum,y)", 1, 1, false));
+ TEST_DO(verify_optimized_multi("y1", "x1y1", "y", 1, 1, true));
+ TEST_DO(verify_optimized_multi("y1", "y1z1", "y", 1, 1, false));
// 3 -> 2 happy/unhappy
- TEST_DO(verify_optimized("reduce(y3*x2y3,sum,y)", 3, 2, true));
- TEST_DO(verify_optimized("reduce(y3*y3z2,sum,y)", 3, 2, false));
+ TEST_DO(verify_optimized_multi("y3", "x2y3", "y", 3, 2, true));
+ TEST_DO(verify_optimized_multi("y3", "y3z2", "y", 3, 2, false));
// 5 -> 8 happy/unhappy
- TEST_DO(verify_optimized("reduce(y5*x8y5,sum,y)", 5, 8, true));
- TEST_DO(verify_optimized("reduce(y5*y5z8,sum,y)", 5, 8, false));
+ TEST_DO(verify_optimized_multi("y5", "x8y5", "y", 5, 8, true));
+ TEST_DO(verify_optimized_multi("y5", "y5z8", "y", 5, 8, false));
// 16 -> 5 happy/unhappy
- TEST_DO(verify_optimized("reduce(y16*x5y16,sum,y)", 16, 5, true));
- TEST_DO(verify_optimized("reduce(y16*y16z5,sum,y)", 16, 5, false));
+ TEST_DO(verify_optimized_multi("y16", "x5y16", "y", 16, 5, true));
+ TEST_DO(verify_optimized_multi("y16", "y16z5", "y", 16, 5, false));
}
TEST("require that various variants of xw product can be optimized") {
- TEST_DO(verify_optimized("reduce(y3*x2y3,sum,y)", 3, 2, true));
- TEST_DO(verify_optimized("reduce(x2y3*y3,sum,y)", 3, 2, true));
TEST_DO(verify_optimized("reduce(join(y3,x2y3,f(x,y)(x*y)),sum,y)", 3, 2, true));
- TEST_DO(verify_optimized("reduce(join(x2y3,y3,f(x,y)(x*y)),sum,y)", 3, 2, true));
+ TEST_DO(verify_optimized("reduce(join(y3,x2y3,f(x,y)(y*x)),sum,y)", 3, 2, true));
}
TEST("require that expressions similar to xw product are not optimized") {
@@ -100,13 +138,9 @@ TEST("require that expressions similar to xw product are not optimized") {
TEST_DO(verify_not_optimized("reduce(y3*x2y3,prod,y)"));
TEST_DO(verify_not_optimized("reduce(y3*x2y3,sum)"));
TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(x+y)),sum,y)"));
- // TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(y*x)),sum,y)"));
TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(x*x)),sum,y)"));
TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(y*y)),sum,y)"));
TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(y*x*1)),sum,y)"));
-}
-
-TEST("require that xw products with incompatible dimensions are not optimized") {
TEST_DO(verify_not_optimized("reduce(y3*x2z3,sum,y)"));
TEST_DO(verify_not_optimized("reduce(y3*x2z3,sum,z)"));
}
@@ -119,16 +153,4 @@ TEST("require that xw product can be debug dumped") {
fprintf(stderr, "%s\n", info[0]->as_string().c_str());
}
-TEST("require that optimization works for float cells") {
- TEST_DO(verify_optimized("reduce(y3f*x2y3,sum,y)", 3, 2, true));
- TEST_DO(verify_optimized("reduce(y3*x2y3f,sum,y)", 3, 2, true));
- TEST_DO(verify_optimized("reduce(y3f*x2y3f,sum,y)", 3, 2, true));
-}
-
-TEST("require that optimization works for float cells with inconvenient dimension nesting") {
- TEST_DO(verify_optimized("reduce(y3f*y3z2,sum,y)", 3, 2, false));
- TEST_DO(verify_optimized("reduce(y3*y3z2f,sum,y)", 3, 2, false));
- TEST_DO(verify_optimized("reduce(y3f*y3z2f,sum,y)", 3, 2, false));
-}
-
TEST_MAIN() { TEST_RUN_ALL(); }