summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorArne H Juul <arnej27959@users.noreply.github.com>2021-03-19 15:53:13 +0100
committerGitHub <noreply@github.com>2021-03-19 15:53:13 +0100
commiteb2ad14727e2c9491073bfa99b70dfcf754ca8e2 (patch)
treeaadaabf0684619080ab32f291eb9c201a51d4132 /eval
parent0960c9e8bcd7e7b336939db1f5ec1a2657175622 (diff)
parent591d6b9ee84860c49db089099fbde0b3e77a3ec7 (diff)
Merge pull request #17071 from vespa-engine/arnej/rewrite_mixed_simple_join_function_test
Arnej/rewrite mixed simple join function test
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/eval/gen_spec/gen_spec_test.cpp10
-rw-r--r--eval/src/tests/instruction/dense_inplace_join_function/dense_inplace_join_function_test.cpp82
-rw-r--r--eval/src/tests/instruction/join_with_number/join_with_number_function_test.cpp7
-rw-r--r--eval/src/tests/instruction/mixed_simple_join_function/mixed_simple_join_function_test.cpp264
-rw-r--r--eval/src/vespa/eval/eval/test/eval_fixture.cpp10
-rw-r--r--eval/src/vespa/eval/eval/test/eval_fixture.h23
-rw-r--r--eval/src/vespa/eval/eval/test/gen_spec.cpp1
-rw-r--r--eval/src/vespa/eval/eval/test/gen_spec.h1
-rw-r--r--eval/src/vespa/eval/instruction/mixed_simple_join_function.cpp20
-rw-r--r--eval/src/vespa/eval/instruction/mixed_simple_join_function.h2
10 files changed, 244 insertions, 176 deletions
diff --git a/eval/src/tests/eval/gen_spec/gen_spec_test.cpp b/eval/src/tests/eval/gen_spec/gen_spec_test.cpp
index d89bf5b8e79..b371e08960f 100644
--- a/eval/src/tests/eval/gen_spec/gen_spec_test.cpp
+++ b/eval/src/tests/eval/gen_spec/gen_spec_test.cpp
@@ -245,6 +245,16 @@ TEST(GenSpecFromDescTest, dim_spec_and_gen_spec_can_be_created_from_desc) {
EXPECT_EQ(gen_desc, expect);
}
+TEST(GenSpecFromDescTest, empty_mapped_dim_possible) {
+ // 'a0_1'
+ auto expect = GenSpec().map("a", 0).gen();
+ auto dim_desc = GenSpec().desc("a0_1").gen();
+ auto gen_desc = GenSpec::from_desc("a0_1").gen();
+ EXPECT_EQ(dim_desc, expect);
+ EXPECT_EQ(gen_desc, expect);
+}
+
+
TEST(GenSpecFromDescTest, multi_character_sizes_work) {
// 'a13b1'
auto expect = GenSpec().idx("a", 13).idx("b", 1).gen();
diff --git a/eval/src/tests/instruction/dense_inplace_join_function/dense_inplace_join_function_test.cpp b/eval/src/tests/instruction/dense_inplace_join_function/dense_inplace_join_function_test.cpp
index 07e995d091f..53ac70c9a86 100644
--- a/eval/src/tests/instruction/dense_inplace_join_function/dense_inplace_join_function_test.cpp
+++ b/eval/src/tests/instruction/dense_inplace_join_function/dense_inplace_join_function_test.cpp
@@ -22,22 +22,21 @@ GenSpec::seq_t glb = [] (size_t) noexcept {
};
EvalFixture::ParamRepo make_params() {
- return EvalFixture::ParamRepo()
- .add("con_x5_A", GenSpec().idx("x", 5).seq(glb))
- .add("con_x5_B", GenSpec().idx("x", 5).seq(glb))
- .add("con_x5_C", GenSpec().idx("x", 5).seq(glb))
- .add("con_x5y3_A", GenSpec().idx("x", 5).idx("y", 3).seq(glb))
- .add("con_x5y3_B", GenSpec().idx("x", 5).idx("y", 3).seq(glb))
- .add_mutable("mut_dbl_A", GenSpec(1.5))
- .add_mutable("mut_dbl_B", GenSpec(2.5))
- .add_mutable("mut_x5_A", GenSpec().idx("x", 5).seq(glb))
- .add_mutable("mut_x5_B", GenSpec().idx("x", 5).seq(glb))
- .add_mutable("mut_x5_C", GenSpec().idx("x", 5).seq(glb))
- .add_mutable("mut_x5f_D", GenSpec().cells_float().idx("x", 5).seq(glb))
- .add_mutable("mut_x5f_E", GenSpec().cells_float().idx("x", 5).seq(glb))
- .add_mutable("mut_x5y3_A", GenSpec().idx("x", 5).idx("y", 3).seq(glb))
- .add_mutable("mut_x5y3_B", GenSpec().idx("x", 5).idx("y", 3).seq(glb))
- .add_mutable("mut_x_sparse", GenSpec().map("x", {"a", "b", "c"}).seq(glb));
+ EvalFixture::ParamRepo repo;
+ for (vespalib::string param : {
+ "x5$1", "x5$2", "x5$3",
+ "x5y3$1", "x5y3$2",
+ "@x5$1", "@x5$2", "@x5$3",
+ "@x5y3$1", "@x5y3$2",
+ "@x3_1$1", "@x3_1$2"
+ })
+ {
+ repo.add(param, CellType::DOUBLE, glb);
+ repo.add(param + "_f", CellType::FLOAT, glb);
+ }
+ repo.add_mutable("mut_dbl_A", GenSpec(1.5));
+ repo.add_mutable("mut_dbl_B", GenSpec(2.5));
+ return repo;
}
EvalFixture::ParamRepo param_repo = make_params();
@@ -47,9 +46,11 @@ void verify_optimized(const vespalib::string &expr, size_t param_idx) {
for (size_t i = 0; i < fixture.num_params(); ++i) {
TEST_STATE(vespalib::make_string("param %zu", i).c_str());
if (i == param_idx) {
- EXPECT_EQUAL(fixture.get_param(i), fixture.result());
+ EXPECT_EQUAL(fixture.param_value(i).cells().data,
+ fixture.result_value().cells().data);
} else {
- EXPECT_NOT_EQUAL(fixture.get_param(i), fixture.result());
+ EXPECT_NOT_EQUAL(fixture.param_value(i).cells().data,
+ fixture.result_value().cells().data);
}
}
}
@@ -70,41 +71,42 @@ void verify_not_optimized(const vespalib::string &expr) {
EvalFixture fixture(prod_factory, expr, param_repo, true, true);
EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo));
for (size_t i = 0; i < fixture.num_params(); ++i) {
- EXPECT_NOT_EQUAL(fixture.get_param(i), fixture.result());
+ EXPECT_NOT_EQUAL(fixture.param_value(i).cells().data,
+ fixture.result_value().cells().data);
}
}
TEST("require that mutable dense concrete tensors are optimized") {
- TEST_DO(verify_p1_optimized("mut_x5_A-mut_x5_B"));
- TEST_DO(verify_p0_optimized("mut_x5_A-con_x5_B"));
- TEST_DO(verify_p1_optimized("con_x5_A-mut_x5_B"));
- TEST_DO(verify_p1_optimized("mut_x5y3_A-mut_x5y3_B"));
- TEST_DO(verify_p0_optimized("mut_x5y3_A-con_x5y3_B"));
- TEST_DO(verify_p1_optimized("con_x5y3_A-mut_x5y3_B"));
+ TEST_DO(verify_p1_optimized("@x5$1-@x5$2"));
+ TEST_DO(verify_p0_optimized("@x5$1-x5$2"));
+ TEST_DO(verify_p1_optimized("x5$1-@x5$2"));
+ TEST_DO(verify_p1_optimized("@x5y3$1-@x5y3$2"));
+ TEST_DO(verify_p0_optimized("@x5y3$1-x5y3$2"));
+ TEST_DO(verify_p1_optimized("x5y3$1-@x5y3$2"));
}
TEST("require that self-join operations can be optimized") {
- TEST_DO(verify_p0_optimized("mut_x5_A+mut_x5_A"));
+ TEST_DO(verify_p0_optimized("@x5$1+@x5$1"));
}
TEST("require that join(tensor,scalar) operations are optimized") {
- TEST_DO(verify_p0_optimized("mut_x5_A-mut_dbl_B"));
- TEST_DO(verify_p1_optimized("mut_dbl_A-mut_x5_B"));
+ TEST_DO(verify_p0_optimized("@x5$1-mut_dbl_B"));
+ TEST_DO(verify_p1_optimized("mut_dbl_A-@x5$2"));
}
TEST("require that join with different tensor shapes are optimized") {
- TEST_DO(verify_p1_optimized("mut_x5_A*mut_x5y3_B"));
+ TEST_DO(verify_p1_optimized("@x5$1*@x5y3$2"));
}
TEST("require that inplace join operations can be chained") {
- TEST_DO(verify_p2_optimized("mut_x5_A+(mut_x5_B+mut_x5_C)"));
- TEST_DO(verify_p0_optimized("(mut_x5_A+con_x5_B)+con_x5_C"));
- TEST_DO(verify_p1_optimized("con_x5_A+(mut_x5_B+con_x5_C)"));
- TEST_DO(verify_p2_optimized("con_x5_A+(con_x5_B+mut_x5_C)"));
+ TEST_DO(verify_p2_optimized("@x5$1+(@x5$2+@x5$3)"));
+ TEST_DO(verify_p0_optimized("(@x5$1+x5$2)+x5$3"));
+ TEST_DO(verify_p1_optimized("x5$1+(@x5$2+x5$3)"));
+ TEST_DO(verify_p2_optimized("x5$1+(x5$2+@x5$3)"));
}
TEST("require that non-mutable tensors are not optimized") {
- TEST_DO(verify_not_optimized("con_x5_A+con_x5_B"));
+ TEST_DO(verify_not_optimized("x5$1+x5$2"));
}
TEST("require that scalar values are not optimized") {
@@ -114,18 +116,18 @@ TEST("require that scalar values are not optimized") {
}
TEST("require that mapped tensors are not optimized") {
- TEST_DO(verify_not_optimized("mut_x_sparse+mut_x_sparse"));
+ TEST_DO(verify_not_optimized("@x3_1$1+@x3_1$2"));
}
TEST("require that optimization works with float cells") {
- TEST_DO(verify_p1_optimized("mut_x5f_D-mut_x5f_E"));
+ TEST_DO(verify_p1_optimized("@x5$1_f-@x5$2_f"));
}
TEST("require that overwritten value must have same cell type as result") {
- TEST_DO(verify_p0_optimized("mut_x5_A-mut_x5f_D"));
- TEST_DO(verify_p1_optimized("mut_x5f_D-mut_x5_A"));
- TEST_DO(verify_not_optimized("con_x5_A-mut_x5f_D"));
- TEST_DO(verify_not_optimized("mut_x5f_D-con_x5_A"));
+ TEST_DO(verify_p0_optimized("@x5$1-@x5$2_f"));
+ TEST_DO(verify_p1_optimized("@x5$2_f-@x5$1"));
+ TEST_DO(verify_not_optimized("x5$1-@x5$2_f"));
+ TEST_DO(verify_not_optimized("@x5$2_f-x5$1"));
}
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/tests/instruction/join_with_number/join_with_number_function_test.cpp b/eval/src/tests/instruction/join_with_number/join_with_number_function_test.cpp
index 7b297db3d3e..a2f18d7f7f7 100644
--- a/eval/src/tests/instruction/join_with_number/join_with_number_function_test.cpp
+++ b/eval/src/tests/instruction/join_with_number/join_with_number_function_test.cpp
@@ -34,10 +34,15 @@ struct FunInfo {
using LookFor = JoinWithNumberFunction;
Primary primary;
bool inplace;
- void verify(const LookFor &fun) const {
+ void verify(const EvalFixture &fixture, const LookFor &fun) const {
EXPECT_TRUE(fun.result_is_mutable());
EXPECT_EQUAL(fun.primary(), primary);
EXPECT_EQUAL(fun.inplace(), inplace);
+ if (inplace) {
+ size_t idx = (fun.primary() == Primary::LHS) ? 0 : 1;
+ EXPECT_EQUAL(fixture.result_value().cells().data,
+ fixture.param_value(idx).cells().data);
+ }
}
};
diff --git a/eval/src/tests/instruction/mixed_simple_join_function/mixed_simple_join_function_test.cpp b/eval/src/tests/instruction/mixed_simple_join_function/mixed_simple_join_function_test.cpp
index 6220682d716..9150fb604be 100644
--- a/eval/src/tests/instruction/mixed_simple_join_function/mixed_simple_join_function_test.cpp
+++ b/eval/src/tests/instruction/mixed_simple_join_function/mixed_simple_join_function_test.cpp
@@ -1,6 +1,8 @@
// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include <vespa/vespalib/testkit/test_kit.h>
+#include <vespa/eval/eval/fast_value.h>
+#include <vespa/eval/eval/simple_value.h>
#include <vespa/eval/eval/tensor_function.h>
#include <vespa/eval/instruction/mixed_simple_join_function.h>
#include <vespa/eval/eval/test/eval_fixture.h>
@@ -15,6 +17,9 @@ using namespace vespalib::eval::tensor_function;
using vespalib::make_string_short::fmt;
+const ValueBuilderFactory &prod_factory = FastValueBuilderFactory::get();
+const ValueBuilderFactory &test_factory = SimpleValueBuilderFactory::get();
+
using Primary = MixedSimpleJoinFunction::Primary;
using Overlap = MixedSimpleJoinFunction::Overlap;
@@ -41,78 +46,92 @@ std::ostream &operator<<(std::ostream &os, Overlap overlap)
}
-const ValueBuilderFactory &prod_factory = FastValueBuilderFactory::get();
-
-EvalFixture::ParamRepo make_params() {
- return EvalFixture::ParamRepo()
- .add("a", GenSpec(1.5))
- .add("b", GenSpec(2.5))
- .add("sparse", GenSpec().map("x", {"a", "b", "c"}))
- .add("mixed", GenSpec().map("x", {"a", "b", "c"}).idx("y", 5).idx("z", 3))
- .add("empty_mixed", GenSpec().map("x", {}).idx("y", 5).idx("z", 3))
- .add_mutable("@mixed", GenSpec().map("x", {"a", "b", "c"}).idx("y", 5).idx("z", 3))
- .add_variants("a1b1c1", GenSpec().idx("a", 1).idx("b", 1).idx("c", 1))
- .add_variants("x1y1z1", GenSpec().idx("x", 1).idx("y", 1).idx("z", 1))
- .add_variants("x3y5z3", GenSpec().idx("x", 3).idx("y", 5).idx("z", 3))
- .add_variants("z3", GenSpec().idx("z", 3))
- .add_variants("c5d1", GenSpec().idx("c", 5).idx("d", 1))
- .add_variants("b1c5", GenSpec().idx("b", 1).idx("c", 5))
- .add_variants("x3y5", GenSpec().idx("x", 3).idx("y", 5).seq([](size_t idx) noexcept { return double((idx * 2) + 3); }))
- .add_variants("x3y5$2", GenSpec().idx("x", 3).idx("y", 5).seq([](size_t idx) noexcept { return double((idx * 3) + 2); }))
- .add_variants("y5", GenSpec().idx("y", 5).seq([](size_t idx) noexcept { return double((idx * 2) + 3); }))
- .add_variants("y5$2", GenSpec().idx("y", 5).seq([](size_t idx) noexcept { return double((idx * 3) + 2); }))
- .add_variants("y5z3", GenSpec().idx("y", 5).idx("z", 3).seq([](size_t idx) noexcept { return double((idx * 2) + 3); }))
- .add_variants("y5z3$2", GenSpec().idx("y", 5).idx("z", 3).seq([](size_t idx) noexcept { return double((idx * 3) + 2); }));
-}
-EvalFixture::ParamRepo param_repo = make_params();
-
-void verify_optimized(const vespalib::string &expr, Primary primary, Overlap overlap, bool pri_mut, size_t factor, int p_inplace = -1) {
- EvalFixture slow_fixture(prod_factory, expr, param_repo, false);
- EvalFixture fixture(prod_factory, expr, param_repo, true, true);
- EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo));
- EXPECT_EQUAL(fixture.result(), slow_fixture.result());
- auto info = fixture.find_all<MixedSimpleJoinFunction>();
- ASSERT_EQUAL(info.size(), 1u);
- EXPECT_TRUE(info[0]->result_is_mutable());
- EXPECT_EQUAL(info[0]->primary(), primary);
- EXPECT_EQUAL(info[0]->overlap(), overlap);
- EXPECT_EQUAL(info[0]->primary_is_mutable(), pri_mut);
- EXPECT_EQUAL(info[0]->factor(), factor);
- EXPECT_TRUE((p_inplace == -1) || (fixture.num_params() > size_t(p_inplace)));
- for (size_t i = 0; i < fixture.num_params(); ++i) {
- if (i == size_t(p_inplace)) {
- EXPECT_EQUAL(fixture.get_param(i), fixture.result());
- } else {
- if (!fixture.result().cells().empty()) {
- EXPECT_NOT_EQUAL(fixture.get_param(i), fixture.result());
+struct FunInfo {
+ using LookFor = MixedSimpleJoinFunction;
+ Overlap overlap;
+ size_t factor;
+ Primary primary;
+ bool l_mut;
+ bool r_mut;
+ bool inplace;
+ void verify(const EvalFixture &fixture, const LookFor &fun) const {
+ EXPECT_TRUE(fun.result_is_mutable());
+ EXPECT_EQUAL(fun.overlap(), overlap);
+ EXPECT_EQUAL(fun.factor(), factor);
+ EXPECT_EQUAL(fun.primary(), primary);
+ if (fun.primary_is_mutable()) {
+ if (fun.primary() == Primary::LHS) {
+ EXPECT_TRUE(l_mut);
}
+ if (fun.primary() == Primary::RHS) {
+ EXPECT_TRUE(r_mut);
+ }
+ }
+ EXPECT_EQUAL(fun.inplace(), inplace);
+ if (fun.inplace()) {
+ EXPECT_TRUE(fun.primary_is_mutable());
+ size_t idx = (fun.primary() == Primary::LHS) ? 0 : 1;
+ EXPECT_EQUAL(fixture.result_value().cells().data,
+ fixture.param_value(idx).cells().data);
+ EXPECT_NOT_EQUAL(fixture.result_value().cells().data,
+ fixture.param_value(1-idx).cells().data);
+ } else {
+ EXPECT_NOT_EQUAL(fixture.result_value().cells().data,
+ fixture.param_value(0).cells().data);
+ EXPECT_NOT_EQUAL(fixture.result_value().cells().data,
+ fixture.param_value(1).cells().data);
}
}
+};
+
+void verify_simple(const vespalib::string &expr, Primary primary, Overlap overlap, size_t factor,
+ bool l_mut, bool r_mut, bool inplace)
+{
+ TEST_STATE(expr.c_str());
+ CellTypeSpace just_double({CellType::DOUBLE}, 2);
+ FunInfo details{overlap, factor, primary, l_mut, r_mut, inplace};
+ EvalFixture::verify<FunInfo>(expr, {details}, just_double);
+ CellTypeSpace just_float({CellType::FLOAT}, 2);
+ EvalFixture::verify<FunInfo>(expr, {details}, just_float);
+}
+
+void verify_optimized(const vespalib::string &expr, Primary primary, Overlap overlap, size_t factor,
+ bool l_mut = false, bool r_mut = false, bool inplace = false)
+{
+ TEST_STATE(expr.c_str());
+ CellTypeSpace all_types(CellTypeUtils::list_types(), 2);
+ FunInfo details{overlap, factor, primary, l_mut, r_mut, inplace};
+ EvalFixture::verify<FunInfo>(expr, {details}, all_types);
}
void verify_not_optimized(const vespalib::string &expr) {
- EvalFixture slow_fixture(prod_factory, expr, param_repo, false);
- EvalFixture fixture(prod_factory, expr, param_repo, true);
- EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo));
- EXPECT_EQUAL(fixture.result(), slow_fixture.result());
- auto info = fixture.find_all<MixedSimpleJoinFunction>();
- EXPECT_TRUE(info.empty());
+ TEST_STATE(expr.c_str());
+ CellTypeSpace just_double({CellType::DOUBLE}, 2);
+ EvalFixture::verify<FunInfo>(expr, {}, just_double);
}
TEST("require that basic join is optimized") {
- TEST_DO(verify_optimized("y5+y5$2", Primary::RHS, Overlap::FULL, false, 1));
+ TEST_DO(verify_optimized("y5+y5$2", Primary::RHS, Overlap::FULL, 1));
+}
+
+TEST("require that inplace is preferred") {
+ TEST_DO(verify_simple("y5+y5$2", Primary::RHS, Overlap::FULL, 1, false, false, false));
+ TEST_DO(verify_simple("y5+@y5$2", Primary::RHS, Overlap::FULL, 1, false, true, true));
+ TEST_DO(verify_simple("@y5+@y5$2", Primary::RHS, Overlap::FULL, 1, true, true, true));
+ TEST_DO(verify_simple("@y5+y5$2", Primary::LHS, Overlap::FULL, 1, true, false, true));
}
TEST("require that unit join is optimized") {
- TEST_DO(verify_optimized("a1b1c1+x1y1z1", Primary::RHS, Overlap::FULL, false, 1));
+ TEST_DO(verify_optimized("a1b1c1+x1y1z1", Primary::RHS, Overlap::FULL, 1));
}
TEST("require that trivial dimensions do not affect overlap calculation") {
- TEST_DO(verify_optimized("c5d1+b1c5", Primary::RHS, Overlap::FULL, false, 1));
+ TEST_DO(verify_optimized("c5d1+b1c5", Primary::RHS, Overlap::FULL, 1));
+ TEST_DO(verify_simple("@c5d1+@b1c5", Primary::RHS, Overlap::FULL, 1, true, true, true));
}
TEST("require that outer nesting is preferred to inner nesting") {
- TEST_DO(verify_optimized("a1b1c1+y5", Primary::RHS, Overlap::OUTER, false, 5));
+ TEST_DO(verify_optimized("a1b1c1+y5", Primary::RHS, Overlap::OUTER, 5));
}
TEST("require that non-subset join is not optimized") {
@@ -144,42 +163,38 @@ struct LhsRhs {
}
};
-vespalib::string adjust_param(const vespalib::string &str, bool float_cells, bool mut_cells, bool is_rhs) {
- vespalib::string result = str;
- if (mut_cells) {
- result = "@" + result;
- }
- if (is_rhs) {
- result += "$2";
- }
- if (float_cells) {
- result += "_f";
- }
- return result;
-}
-
TEST("require that various parameter combinations work") {
- for (bool left_float: {false, true}) {
- for (bool right_float: {false, true}) {
- bool float_result = (left_float && right_float);
+ for (CellType lct : CellTypeUtils::list_types()) {
+ for (CellType rct : CellTypeUtils::list_types()) {
for (bool left_mut: {false, true}) {
for (bool right_mut: {false, true}) {
- for (const char *op_pattern: {"%s+%s", "%s-%s", "%s*%s"}) {
+ for (const char * expr: {"a+b", "a-b", "a*b"}) {
for (const LhsRhs &params:
- { LhsRhs("y5", "y5", 5, 5, Overlap::FULL),
- LhsRhs("y5", "x3y5", 5, 15, Overlap::INNER),
- LhsRhs("y5", "y5z3", 5, 15, Overlap::OUTER),
- LhsRhs("x3y5", "y5", 15, 5, Overlap::INNER),
- LhsRhs("y5z3", "y5", 15, 5, Overlap::OUTER)})
+ { LhsRhs("y5", "y5", 5, 5, Overlap::FULL),
+ LhsRhs("y5", "x3y5", 5, 15, Overlap::INNER),
+ LhsRhs("y5", "y5z3", 5, 15, Overlap::OUTER),
+ LhsRhs("x3y5", "y5", 15, 5, Overlap::INNER),
+ LhsRhs("y5z3", "y5", 15, 5, Overlap::OUTER)})
{
- vespalib::string left = adjust_param(params.lhs, left_float, left_mut, false);
- vespalib::string right = adjust_param(params.rhs, right_float, right_mut, true);
- vespalib::string expr = fmt(op_pattern, left.c_str(), right.c_str());
- TEST_STATE(expr.c_str());
+ EvalFixture::ParamRepo param_repo;
+ auto a_spec = GenSpec::from_desc(params.lhs).cells(lct).seq(AX_B(0.25, 1.125));
+ auto b_spec = GenSpec::from_desc(params.rhs).cells(rct).seq(AX_B(-0.25, 25.0));
+ if (left_mut) {
+ param_repo.add_mutable("a", a_spec);
+ } else {
+ param_repo.add("a", a_spec);
+ }
+ if (right_mut) {
+ param_repo.add_mutable("b", b_spec);
+ } else {
+ param_repo.add("b", b_spec);
+ }
+ TEST_STATE(expr);
+ CellType result_ct = CellMeta::join(CellMeta{lct, false}, CellMeta{rct, false}).cell_type;
Primary primary = Primary::RHS;
if (params.overlap == Overlap::FULL) {
- bool w_lhs = ((left_float == float_result) && left_mut);
- bool w_rhs = ((right_float == float_result) && right_mut);
+ bool w_lhs = (lct == result_ct) && left_mut;
+ bool w_rhs = (rct == result_ct) && right_mut;
if (w_lhs && !w_rhs) {
primary = Primary::LHS;
}
@@ -187,12 +202,19 @@ TEST("require that various parameter combinations work") {
primary = Primary::LHS;
}
bool pri_mut = (primary == Primary::LHS) ? left_mut : right_mut;
- bool pri_float = (primary == Primary::LHS) ? left_float : right_float;
- int p_inplace = -1;
- if (pri_mut && (pri_float == float_result)) {
- p_inplace = (primary == Primary::LHS) ? 0 : 1;
- }
- verify_optimized(expr, primary, params.overlap, pri_mut, params.factor, p_inplace);
+ bool pri_same_ct = (primary == Primary::LHS) ? (lct == result_ct) : (rct == result_ct);
+ bool inplace = (pri_mut && pri_same_ct);
+ auto expect = EvalFixture::ref(expr, param_repo);
+ EvalFixture slow_fixture(prod_factory, expr, param_repo, false);
+ EvalFixture test_fixture(test_factory, expr, param_repo, true, true);
+ EvalFixture fixture(prod_factory, expr, param_repo, true, true);
+ EXPECT_EQUAL(fixture.result(), expect);
+ EXPECT_EQUAL(slow_fixture.result(), expect);
+ EXPECT_EQUAL(test_fixture.result(), expect);
+ auto info = fixture.find_all<FunInfo::LookFor>();
+ ASSERT_EQUAL(info.size(), 1u);
+ FunInfo details{params.overlap, params.factor, primary, left_mut, right_mut, inplace};
+ details.verify(fixture, *info[0]);
}
}
}
@@ -202,50 +224,56 @@ TEST("require that various parameter combinations work") {
}
TEST("require that scalar values are not optimized") {
- TEST_DO(verify_not_optimized("a+b"));
- TEST_DO(verify_not_optimized("a+y5"));
- TEST_DO(verify_not_optimized("y5+b"));
- TEST_DO(verify_not_optimized("a+sparse"));
- TEST_DO(verify_not_optimized("sparse+a"));
- TEST_DO(verify_not_optimized("a+mixed"));
- TEST_DO(verify_not_optimized("mixed+a"));
+ TEST_DO(verify_not_optimized("reduce(v3,sum)+reduce(v4,sum)"));
+ TEST_DO(verify_not_optimized("reduce(v3,sum)+y5"));
+ TEST_DO(verify_not_optimized("y5+reduce(v3,sum)"));
+ TEST_DO(verify_not_optimized("reduce(v3,sum)+x3_1"));
+ TEST_DO(verify_not_optimized("x3_1+reduce(v3,sum)"));
+ TEST_DO(verify_not_optimized("reduce(v3,sum)+x3_1y5z3"));
+ TEST_DO(verify_not_optimized("x3_1y5z3+reduce(v3,sum)"));
}
TEST("require that sparse tensors are mostly not optimized") {
- TEST_DO(verify_not_optimized("sparse+sparse"));
- TEST_DO(verify_not_optimized("sparse+y5"));
- TEST_DO(verify_not_optimized("y5+sparse"));
- TEST_DO(verify_not_optimized("sparse+mixed"));
- TEST_DO(verify_not_optimized("mixed+sparse"));
+ TEST_DO(verify_not_optimized("x3_1+x3_1$2"));
+ TEST_DO(verify_not_optimized("x3_1+y5"));
+ TEST_DO(verify_not_optimized("y5+x3_1"));
+ TEST_DO(verify_not_optimized("x3_1+x3_1y5z3"));
+ TEST_DO(verify_not_optimized("x3_1y5z3+x3_1"));
}
TEST("require that sparse tensor joined with trivial dense tensor is optimized") {
- TEST_DO(verify_optimized("sparse+a1b1c1", Primary::LHS, Overlap::FULL, false, 1));
- TEST_DO(verify_optimized("a1b1c1+sparse", Primary::RHS, Overlap::FULL, false, 1));
+ TEST_DO(verify_optimized("x3_1+a1b1c1", Primary::LHS, Overlap::FULL, 1));
+ TEST_DO(verify_optimized("a1b1c1+x3_1", Primary::RHS, Overlap::FULL, 1));
}
TEST("require that primary tensor can be empty") {
- TEST_DO(verify_optimized("empty_mixed+y5z3", Primary::LHS, Overlap::FULL, false, 1));
- TEST_DO(verify_optimized("y5z3+empty_mixed", Primary::RHS, Overlap::FULL, false, 1));
+ TEST_DO(verify_optimized("x0_1y5z3+y5z3", Primary::LHS, Overlap::FULL, 1));
+ TEST_DO(verify_optimized("y5z3+x0_1y5z3", Primary::RHS, Overlap::FULL, 1));
}
TEST("require that mixed tensors can be optimized") {
- TEST_DO(verify_not_optimized("mixed+mixed"));
- TEST_DO(verify_optimized("mixed+y5z3", Primary::LHS, Overlap::FULL, false, 1));
- TEST_DO(verify_optimized("mixed+y5", Primary::LHS, Overlap::OUTER, false, 3));
- TEST_DO(verify_optimized("mixed+z3", Primary::LHS, Overlap::INNER, false, 5));
- TEST_DO(verify_optimized("y5z3+mixed", Primary::RHS, Overlap::FULL, false, 1));
- TEST_DO(verify_optimized("y5+mixed", Primary::RHS, Overlap::OUTER, false, 3));
- TEST_DO(verify_optimized("z3+mixed", Primary::RHS, Overlap::INNER, false, 5));
+ TEST_DO(verify_not_optimized("x3_1y5z3+x3_1y5z3$2"));
+ TEST_DO(verify_optimized("x3_1y5z3+y5z3", Primary::LHS, Overlap::FULL, 1));
+ TEST_DO(verify_optimized("x3_1y5z3+y5", Primary::LHS, Overlap::OUTER, 3));
+ TEST_DO(verify_optimized("x3_1y5z3+z3", Primary::LHS, Overlap::INNER, 5));
+ TEST_DO(verify_optimized("y5z3+x3_1y5z3", Primary::RHS, Overlap::FULL, 1));
+ TEST_DO(verify_optimized("y5+x3_1y5z3", Primary::RHS, Overlap::OUTER, 3));
+ TEST_DO(verify_optimized("z3+x3_1y5z3", Primary::RHS, Overlap::INNER, 5));
}
TEST("require that mixed tensors can be inplace") {
- TEST_DO(verify_optimized("@mixed+y5z3", Primary::LHS, Overlap::FULL, true, 1, 0));
- TEST_DO(verify_optimized("@mixed+y5", Primary::LHS, Overlap::OUTER, true, 3, 0));
- TEST_DO(verify_optimized("@mixed+z3", Primary::LHS, Overlap::INNER, true, 5, 0));
- TEST_DO(verify_optimized("y5z3+@mixed", Primary::RHS, Overlap::FULL, true, 1, 1));
- TEST_DO(verify_optimized("y5+@mixed", Primary::RHS, Overlap::OUTER, true, 3, 1));
- TEST_DO(verify_optimized("z3+@mixed", Primary::RHS, Overlap::INNER, true, 5, 1));
+ TEST_DO(verify_simple("@x3_1y5z3+y5z3", Primary::LHS, Overlap::FULL, 1, true, false, true));
+ TEST_DO(verify_simple("@x3_1y5z3+y5", Primary::LHS, Overlap::OUTER, 3, true, false, true));
+ TEST_DO(verify_simple("@x3_1y5z3+z3", Primary::LHS, Overlap::INNER, 5, true, false, true));
+ TEST_DO(verify_simple("@x3_1y5z3+@y5z3", Primary::LHS, Overlap::FULL, 1, true, true, true));
+ TEST_DO(verify_simple("@x3_1y5z3+@y5", Primary::LHS, Overlap::OUTER, 3, true, true, true));
+ TEST_DO(verify_simple("@x3_1y5z3+@z3", Primary::LHS, Overlap::INNER, 5, true, true, true));
+ TEST_DO(verify_simple("y5z3+@x3_1y5z3", Primary::RHS, Overlap::FULL, 1, false, true, true));
+ TEST_DO(verify_simple("y5+@x3_1y5z3", Primary::RHS, Overlap::OUTER, 3, false, true, true));
+ TEST_DO(verify_simple("z3+@x3_1y5z3", Primary::RHS, Overlap::INNER, 5, false, true, true));
+ TEST_DO(verify_simple("@y5z3+@x3_1y5z3", Primary::RHS, Overlap::FULL, 1, true, true, true));
+ TEST_DO(verify_simple("@y5+@x3_1y5z3", Primary::RHS, Overlap::OUTER, 3, true, true, true));
+ TEST_DO(verify_simple("@z3+@x3_1y5z3", Primary::RHS, Overlap::INNER, 5, true, true, true));
}
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/vespa/eval/eval/test/eval_fixture.cpp b/eval/src/vespa/eval/eval/test/eval_fixture.cpp
index 75626307bfd..77cc1231f9c 100644
--- a/eval/src/vespa/eval/eval/test/eval_fixture.cpp
+++ b/eval/src/vespa/eval/eval/test/eval_fixture.cpp
@@ -191,20 +191,14 @@ EvalFixture::EvalFixture(const ValueBuilderFactory &factory,
_ictx(_ifun),
_param_values(make_params(_factory, *_function, param_repo)),
_params(get_refs(_param_values)),
- _result(spec_from_value(_ifun.eval(_ictx, _params)))
+ _result_value(_ifun.eval(_ictx, _params)),
+ _result(spec_from_value(_result_value))
{
auto result_type = ValueType::from_spec(_result.type());
ASSERT_TRUE(!result_type.is_error());
TEST_DO(detect_param_tampering(param_repo, allow_mutable));
}
-const TensorSpec
-EvalFixture::get_param(size_t idx) const
-{
- ASSERT_LESS(idx, _param_values.size());
- return spec_from_value(*(_param_values[idx]));
-}
-
size_t
EvalFixture::num_params() const
{
diff --git a/eval/src/vespa/eval/eval/test/eval_fixture.h b/eval/src/vespa/eval/eval/test/eval_fixture.h
index dba26f2b270..e99e6755317 100644
--- a/eval/src/vespa/eval/eval/test/eval_fixture.h
+++ b/eval/src/vespa/eval/eval/test/eval_fixture.h
@@ -75,6 +75,7 @@ private:
InterpretedFunction::Context _ictx;
std::vector<Value::UP> _param_values;
SimpleObjectParams _params;
+ const Value &_result_value;
TensorSpec _result;
template <typename T>
@@ -91,6 +92,21 @@ private:
void detect_param_tampering(const ParamRepo &param_repo, bool allow_mutable) const;
+ template <typename FunInfo>
+ auto verify_callback(const FunInfo &verificator,
+ const typename FunInfo::LookFor &what) const
+ -> decltype(verificator.verify(what))
+ {
+ return verificator.verify(what);
+ }
+ template <typename FunInfo>
+ auto verify_callback(const FunInfo &verificator,
+ const typename FunInfo::LookFor &what) const
+ -> decltype(verificator.verify(*this, what))
+ {
+ return verificator.verify(*this, what);
+ }
+
public:
EvalFixture(const ValueBuilderFactory &factory, const vespalib::string &expr, const ParamRepo &param_repo,
bool optimized = true, bool allow_mutable = false);
@@ -101,8 +117,9 @@ public:
find_all(_tensor_function, list);
return list;
}
+ const Value &result_value() const { return _result_value; }
+ const Value &param_value(size_t idx) const { return *(_param_values[idx]); }
const TensorSpec &result() const { return _result; }
- const TensorSpec get_param(size_t idx) const;
size_t num_params() const;
static TensorSpec ref(const vespalib::string &expr, const ParamRepo &param_repo);
static TensorSpec prod(const vespalib::string &expr, const ParamRepo &param_repo) {
@@ -120,7 +137,7 @@ public:
// trailer starting with '$' ('a5b3$2') to allow multiple
// parameters with the same description as well as scalars
// ('$this_is_a_scalar').
-
+
template <typename FunInfo>
static void verify(const vespalib::string &expr, const std::vector<FunInfo> &fun_info, CellTypeSpace cell_type_space) {
auto fun = Function::parse(expr);
@@ -140,7 +157,7 @@ public:
auto info = fixture.find_all<typename FunInfo::LookFor>();
ASSERT_EQUAL(info.size(), fun_info.size());
for (size_t i = 0; i < fun_info.size(); ++i) {
- fun_info[i].verify(*info[i]);
+ fixture.verify_callback<FunInfo>(fun_info[i], *info[i]);
}
}
}
diff --git a/eval/src/vespa/eval/eval/test/gen_spec.cpp b/eval/src/vespa/eval/eval/test/gen_spec.cpp
index 913c23200ff..96c2c16eaa0 100644
--- a/eval/src/vespa/eval/eval/test/gen_spec.cpp
+++ b/eval/src/vespa/eval/eval/test/gen_spec.cpp
@@ -73,7 +73,6 @@ DimSpec::from_desc(const vespalib::string &desc)
assert(idx < desc.size());
assert(is_num(desc[idx]));
size_t num = as_num(desc[idx++]);
- assert(num != 0); // catch leading zeroes/zero size
while ((idx < desc.size()) && is_num(desc[idx])) {
num = (num * 10) + as_num(desc[idx++]);
}
diff --git a/eval/src/vespa/eval/eval/test/gen_spec.h b/eval/src/vespa/eval/eval/test/gen_spec.h
index def24e6711f..f0eca6074dc 100644
--- a/eval/src/vespa/eval/eval/test/gen_spec.h
+++ b/eval/src/vespa/eval/eval/test/gen_spec.h
@@ -121,6 +121,7 @@ public:
const seq_t &seq() const { return _seq; }
GenSpec cpy() const { return *this; }
GenSpec &idx(const vespalib::string &name, size_t size) {
+ assert(size != 0);
_dims.emplace_back(name, size);
return *this;
}
diff --git a/eval/src/vespa/eval/instruction/mixed_simple_join_function.cpp b/eval/src/vespa/eval/instruction/mixed_simple_join_function.cpp
index 21c6f945609..70134196d1e 100644
--- a/eval/src/vespa/eval/instruction/mixed_simple_join_function.cpp
+++ b/eval/src/vespa/eval/instruction/mixed_simple_join_function.cpp
@@ -176,14 +176,24 @@ MixedSimpleJoinFunction::MixedSimpleJoinFunction(const ValueType &result_type,
MixedSimpleJoinFunction::~MixedSimpleJoinFunction() = default;
+
+const TensorFunction &
+MixedSimpleJoinFunction::primary_child() const
+{
+ return (_primary == Primary::LHS) ? lhs() : rhs();
+}
+
bool
MixedSimpleJoinFunction::primary_is_mutable() const
{
- if (_primary == Primary::LHS) {
- return lhs().result_is_mutable();
- } else {
- return rhs().result_is_mutable();
- }
+ return primary_child().result_is_mutable();
+}
+
+bool
+MixedSimpleJoinFunction::inplace() const
+{
+ return primary_is_mutable() &&
+ (result_type().cell_type() == primary_child().result_type().cell_type());
}
size_t
diff --git a/eval/src/vespa/eval/instruction/mixed_simple_join_function.h b/eval/src/vespa/eval/instruction/mixed_simple_join_function.h
index 94e5f3c52b5..7658ff93689 100644
--- a/eval/src/vespa/eval/instruction/mixed_simple_join_function.h
+++ b/eval/src/vespa/eval/instruction/mixed_simple_join_function.h
@@ -28,6 +28,7 @@ public:
private:
Primary _primary;
Overlap _overlap;
+ const TensorFunction &primary_child() const;
public:
MixedSimpleJoinFunction(const ValueType &result_type,
const TensorFunction &lhs,
@@ -39,6 +40,7 @@ public:
Primary primary() const { return _primary; }
Overlap overlap() const { return _overlap; }
bool primary_is_mutable() const;
+ bool inplace() const;
size_t factor() const;
InterpretedFunction::Instruction compile_self(const ValueBuilderFactory &factory, Stash &stash) const override;
static const TensorFunction &optimize(const TensorFunction &expr, Stash &stash);