summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2021-03-19 06:09:32 +0000
committerArne Juul <arnej@verizonmedia.com>2021-03-19 11:26:10 +0000
commite349ad115e423bde38ae1894d9fcb290655b20c7 (patch)
treea6dc3b861eea4b130f26c771028c332ee6630eb4 /eval
parentef422d3dcdfc9cac070446e9c855e3eb242fba43 (diff)
rewrite mixed_simple_join_function_test using EvalFixture::verify
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/instruction/mixed_simple_join_function/mixed_simple_join_function_test.cpp240
-rw-r--r--eval/src/vespa/eval/instruction/mixed_simple_join_function.cpp20
-rw-r--r--eval/src/vespa/eval/instruction/mixed_simple_join_function.h2
3 files changed, 134 insertions, 128 deletions
diff --git a/eval/src/tests/instruction/mixed_simple_join_function/mixed_simple_join_function_test.cpp b/eval/src/tests/instruction/mixed_simple_join_function/mixed_simple_join_function_test.cpp
index 6220682d716..b5f2402a77c 100644
--- a/eval/src/tests/instruction/mixed_simple_join_function/mixed_simple_join_function_test.cpp
+++ b/eval/src/tests/instruction/mixed_simple_join_function/mixed_simple_join_function_test.cpp
@@ -41,78 +41,95 @@ std::ostream &operator<<(std::ostream &os, Overlap overlap)
}
-const ValueBuilderFactory &prod_factory = FastValueBuilderFactory::get();
-
-EvalFixture::ParamRepo make_params() {
- return EvalFixture::ParamRepo()
- .add("a", GenSpec(1.5))
- .add("b", GenSpec(2.5))
- .add("sparse", GenSpec().map("x", {"a", "b", "c"}))
- .add("mixed", GenSpec().map("x", {"a", "b", "c"}).idx("y", 5).idx("z", 3))
- .add("empty_mixed", GenSpec().map("x", {}).idx("y", 5).idx("z", 3))
- .add_mutable("@mixed", GenSpec().map("x", {"a", "b", "c"}).idx("y", 5).idx("z", 3))
- .add_variants("a1b1c1", GenSpec().idx("a", 1).idx("b", 1).idx("c", 1))
- .add_variants("x1y1z1", GenSpec().idx("x", 1).idx("y", 1).idx("z", 1))
- .add_variants("x3y5z3", GenSpec().idx("x", 3).idx("y", 5).idx("z", 3))
- .add_variants("z3", GenSpec().idx("z", 3))
- .add_variants("c5d1", GenSpec().idx("c", 5).idx("d", 1))
- .add_variants("b1c5", GenSpec().idx("b", 1).idx("c", 5))
- .add_variants("x3y5", GenSpec().idx("x", 3).idx("y", 5).seq([](size_t idx) noexcept { return double((idx * 2) + 3); }))
- .add_variants("x3y5$2", GenSpec().idx("x", 3).idx("y", 5).seq([](size_t idx) noexcept { return double((idx * 3) + 2); }))
- .add_variants("y5", GenSpec().idx("y", 5).seq([](size_t idx) noexcept { return double((idx * 2) + 3); }))
- .add_variants("y5$2", GenSpec().idx("y", 5).seq([](size_t idx) noexcept { return double((idx * 3) + 2); }))
- .add_variants("y5z3", GenSpec().idx("y", 5).idx("z", 3).seq([](size_t idx) noexcept { return double((idx * 2) + 3); }))
- .add_variants("y5z3$2", GenSpec().idx("y", 5).idx("z", 3).seq([](size_t idx) noexcept { return double((idx * 3) + 2); }));
-}
-EvalFixture::ParamRepo param_repo = make_params();
-
-void verify_optimized(const vespalib::string &expr, Primary primary, Overlap overlap, bool pri_mut, size_t factor, int p_inplace = -1) {
- EvalFixture slow_fixture(prod_factory, expr, param_repo, false);
- EvalFixture fixture(prod_factory, expr, param_repo, true, true);
- EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo));
- EXPECT_EQUAL(fixture.result(), slow_fixture.result());
- auto info = fixture.find_all<MixedSimpleJoinFunction>();
- ASSERT_EQUAL(info.size(), 1u);
- EXPECT_TRUE(info[0]->result_is_mutable());
- EXPECT_EQUAL(info[0]->primary(), primary);
- EXPECT_EQUAL(info[0]->overlap(), overlap);
- EXPECT_EQUAL(info[0]->primary_is_mutable(), pri_mut);
- EXPECT_EQUAL(info[0]->factor(), factor);
- EXPECT_TRUE((p_inplace == -1) || (fixture.num_params() > size_t(p_inplace)));
- for (size_t i = 0; i < fixture.num_params(); ++i) {
- if (i == size_t(p_inplace)) {
- EXPECT_EQUAL(fixture.get_param(i), fixture.result());
- } else {
- if (!fixture.result().cells().empty()) {
- EXPECT_NOT_EQUAL(fixture.get_param(i), fixture.result());
+struct FunInfo {
+ using LookFor = MixedSimpleJoinFunction;
+ Overlap overlap;
+ size_t factor;
+ std::optional<Primary> primary;
+ bool l_mut;
+ bool r_mut;
+ bool require_inplace;
+ void verify(const EvalFixture &fixture, const LookFor &fun) const {
+ EXPECT_TRUE(fun.result_is_mutable());
+ EXPECT_EQUAL(fun.overlap(), overlap);
+ EXPECT_EQUAL(fun.factor(), factor);
+ if (primary.has_value()) {
+ EXPECT_EQUAL(fun.primary(), primary.value());
+ }
+ if (fun.primary_is_mutable()) {
+ if (fun.primary() == Primary::LHS) {
+ EXPECT_TRUE(l_mut);
+ }
+ if (fun.primary() == Primary::RHS) {
+ EXPECT_TRUE(r_mut);
}
}
+ if (require_inplace) {
+ EXPECT_TRUE(fun.inplace());
+ }
+ if (fun.inplace()) {
+ EXPECT_TRUE(fun.primary_is_mutable());
+ size_t idx = (fun.primary() == Primary::LHS) ? 0 : 1;
+ EXPECT_EQUAL(fixture.result_value().cells().data,
+ fixture.param_value(idx).cells().data);
+ }
}
+};
+
+void verify_simple(const vespalib::string &expr, Primary primary, Overlap overlap, size_t factor,
+ bool l_mut, bool r_mut, bool inplace)
+{
+ TEST_STATE(expr.c_str());
+ CellTypeSpace just_double({CellType::DOUBLE}, 2);
+ FunInfo details{overlap, factor, primary, l_mut, r_mut, inplace};
+ EvalFixture::verify<FunInfo>(expr, {details}, just_double);
+}
+
+void verify_optimized(const vespalib::string &expr, Primary primary, Overlap overlap, size_t factor,
+ bool l_mut = false, bool r_mut = false)
+{
+ TEST_STATE(expr.c_str());
+ CellTypeSpace all_types(CellTypeUtils::list_types(), 2);
+ FunInfo details{overlap, factor, primary, l_mut, r_mut, false};
+ EvalFixture::verify<FunInfo>(expr, {details}, all_types);
+}
+
+void verify_optimized(const vespalib::string &expr, Overlap overlap, size_t factor,
+ bool l_mut = false, bool r_mut = false)
+{
+ TEST_STATE(expr.c_str());
+ CellTypeSpace all_types(CellTypeUtils::list_types(), 2);
+ FunInfo details{overlap, factor, std::nullopt, l_mut, r_mut, false};
+ EvalFixture::verify<FunInfo>(expr, {details}, all_types);
}
void verify_not_optimized(const vespalib::string &expr) {
- EvalFixture slow_fixture(prod_factory, expr, param_repo, false);
- EvalFixture fixture(prod_factory, expr, param_repo, true);
- EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo));
- EXPECT_EQUAL(fixture.result(), slow_fixture.result());
- auto info = fixture.find_all<MixedSimpleJoinFunction>();
- EXPECT_TRUE(info.empty());
+ TEST_STATE(expr.c_str());
+ CellTypeSpace just_double({CellType::DOUBLE}, 2);
+ EvalFixture::verify<FunInfo>(expr, {}, just_double);
}
TEST("require that basic join is optimized") {
- TEST_DO(verify_optimized("y5+y5$2", Primary::RHS, Overlap::FULL, false, 1));
+ TEST_DO(verify_optimized("y5+y5$2", Primary::RHS, Overlap::FULL, 1));
+}
+
+TEST("require that inplace is preferred") {
+ TEST_DO(verify_simple("y5+y5$2", Primary::RHS, Overlap::FULL, 1, false, false, false));
+ TEST_DO(verify_simple("y5+@y5$2", Primary::RHS, Overlap::FULL, 1, false, true, true));
+ TEST_DO(verify_simple("@y5+@y5$2", Primary::RHS, Overlap::FULL, 1, true, true, true));
+ TEST_DO(verify_simple("@y5+y5$2", Primary::LHS, Overlap::FULL, 1, true, false, true));
}
TEST("require that unit join is optimized") {
- TEST_DO(verify_optimized("a1b1c1+x1y1z1", Primary::RHS, Overlap::FULL, false, 1));
+ TEST_DO(verify_optimized("a1b1c1+x1y1z1", Primary::RHS, Overlap::FULL, 1));
}
TEST("require that trivial dimensions do not affect overlap calculation") {
- TEST_DO(verify_optimized("c5d1+b1c5", Primary::RHS, Overlap::FULL, false, 1));
+ TEST_DO(verify_optimized("c5d1+b1c5", Primary::RHS, Overlap::FULL, 1));
}
TEST("require that outer nesting is preferred to inner nesting") {
- TEST_DO(verify_optimized("a1b1c1+y5", Primary::RHS, Overlap::OUTER, false, 5));
+ TEST_DO(verify_optimized("a1b1c1+y5", Primary::RHS, Overlap::OUTER, 5));
}
TEST("require that non-subset join is not optimized") {
@@ -144,7 +161,7 @@ struct LhsRhs {
}
};
-vespalib::string adjust_param(const vespalib::string &str, bool float_cells, bool mut_cells, bool is_rhs) {
+vespalib::string adjust_param(const vespalib::string &str, bool mut_cells, bool is_rhs) {
vespalib::string result = str;
if (mut_cells) {
result = "@" + result;
@@ -152,49 +169,26 @@ vespalib::string adjust_param(const vespalib::string &str, bool float_cells, boo
if (is_rhs) {
result += "$2";
}
- if (float_cells) {
- result += "_f";
- }
return result;
}
TEST("require that various parameter combinations work") {
- for (bool left_float: {false, true}) {
- for (bool right_float: {false, true}) {
- bool float_result = (left_float && right_float);
- for (bool left_mut: {false, true}) {
- for (bool right_mut: {false, true}) {
- for (const char *op_pattern: {"%s+%s", "%s-%s", "%s*%s"}) {
- for (const LhsRhs &params:
- { LhsRhs("y5", "y5", 5, 5, Overlap::FULL),
- LhsRhs("y5", "x3y5", 5, 15, Overlap::INNER),
- LhsRhs("y5", "y5z3", 5, 15, Overlap::OUTER),
- LhsRhs("x3y5", "y5", 15, 5, Overlap::INNER),
- LhsRhs("y5z3", "y5", 15, 5, Overlap::OUTER)})
- {
- vespalib::string left = adjust_param(params.lhs, left_float, left_mut, false);
- vespalib::string right = adjust_param(params.rhs, right_float, right_mut, true);
- vespalib::string expr = fmt(op_pattern, left.c_str(), right.c_str());
- TEST_STATE(expr.c_str());
- Primary primary = Primary::RHS;
- if (params.overlap == Overlap::FULL) {
- bool w_lhs = ((left_float == float_result) && left_mut);
- bool w_rhs = ((right_float == float_result) && right_mut);
- if (w_lhs && !w_rhs) {
- primary = Primary::LHS;
- }
- } else if (params.lhs_size > params.rhs_size) {
- primary = Primary::LHS;
- }
- bool pri_mut = (primary == Primary::LHS) ? left_mut : right_mut;
- bool pri_float = (primary == Primary::LHS) ? left_float : right_float;
- int p_inplace = -1;
- if (pri_mut && (pri_float == float_result)) {
- p_inplace = (primary == Primary::LHS) ? 0 : 1;
- }
- verify_optimized(expr, primary, params.overlap, pri_mut, params.factor, p_inplace);
- }
- }
+ for (bool left_mut: {false, true}) {
+ for (bool right_mut: {false, true}) {
+ for (const char *op_pattern: {"%s+%s", "%s-%s", "%s*%s"}) {
+ for (const LhsRhs &params:
+ { LhsRhs("y5", "y5", 5, 5, Overlap::FULL),
+ LhsRhs("y5", "x3y5", 5, 15, Overlap::INNER),
+ LhsRhs("y5", "y5z3", 5, 15, Overlap::OUTER),
+ LhsRhs("x3y5", "y5", 15, 5, Overlap::INNER),
+ LhsRhs("y5z3", "y5", 15, 5, Overlap::OUTER)})
+ {
+ vespalib::string left = adjust_param(params.lhs, left_mut, false);
+ vespalib::string right = adjust_param(params.rhs, right_mut, true);
+ vespalib::string expr = fmt(op_pattern, left.c_str(), right.c_str());
+ TEST_STATE(expr.c_str());
+ verify_optimized(expr, params.overlap, params.factor,
+ left_mut, right_mut);
}
}
}
@@ -202,50 +196,50 @@ TEST("require that various parameter combinations work") {
}
TEST("require that scalar values are not optimized") {
- TEST_DO(verify_not_optimized("a+b"));
- TEST_DO(verify_not_optimized("a+y5"));
- TEST_DO(verify_not_optimized("y5+b"));
- TEST_DO(verify_not_optimized("a+sparse"));
- TEST_DO(verify_not_optimized("sparse+a"));
- TEST_DO(verify_not_optimized("a+mixed"));
- TEST_DO(verify_not_optimized("mixed+a"));
+ TEST_DO(verify_not_optimized("reduce(v3,sum)+reduce(v4,sum)"));
+ TEST_DO(verify_not_optimized("reduce(v3,sum)+y5"));
+ TEST_DO(verify_not_optimized("y5+reduce(v3,sum)"));
+ TEST_DO(verify_not_optimized("reduce(v3,sum)+x3_1"));
+ TEST_DO(verify_not_optimized("x3_1+reduce(v3,sum)"));
+ TEST_DO(verify_not_optimized("reduce(v3,sum)+x3_1y5z3"));
+ TEST_DO(verify_not_optimized("x3_1y5z3+reduce(v3,sum)"));
}
TEST("require that sparse tensors are mostly not optimized") {
- TEST_DO(verify_not_optimized("sparse+sparse"));
- TEST_DO(verify_not_optimized("sparse+y5"));
- TEST_DO(verify_not_optimized("y5+sparse"));
- TEST_DO(verify_not_optimized("sparse+mixed"));
- TEST_DO(verify_not_optimized("mixed+sparse"));
+ TEST_DO(verify_not_optimized("x3_1+x3_1$2"));
+ TEST_DO(verify_not_optimized("x3_1+y5"));
+ TEST_DO(verify_not_optimized("y5+x3_1"));
+ TEST_DO(verify_not_optimized("x3_1+x3_1y5z3"));
+ TEST_DO(verify_not_optimized("x3_1y5z3+x3_1"));
}
TEST("require that sparse tensor joined with trivial dense tensor is optimized") {
- TEST_DO(verify_optimized("sparse+a1b1c1", Primary::LHS, Overlap::FULL, false, 1));
- TEST_DO(verify_optimized("a1b1c1+sparse", Primary::RHS, Overlap::FULL, false, 1));
+ TEST_DO(verify_optimized("x3_1+a1b1c1", Primary::LHS, Overlap::FULL, 1));
+ TEST_DO(verify_optimized("a1b1c1+x3_1", Primary::RHS, Overlap::FULL, 1));
}
TEST("require that primary tensor can be empty") {
- TEST_DO(verify_optimized("empty_mixed+y5z3", Primary::LHS, Overlap::FULL, false, 1));
- TEST_DO(verify_optimized("y5z3+empty_mixed", Primary::RHS, Overlap::FULL, false, 1));
+ TEST_DO(verify_optimized("x0_1y5z3+y5z3", Primary::LHS, Overlap::FULL, 1));
+ TEST_DO(verify_optimized("y5z3+x0_1y5z3", Primary::RHS, Overlap::FULL, 1));
}
TEST("require that mixed tensors can be optimized") {
- TEST_DO(verify_not_optimized("mixed+mixed"));
- TEST_DO(verify_optimized("mixed+y5z3", Primary::LHS, Overlap::FULL, false, 1));
- TEST_DO(verify_optimized("mixed+y5", Primary::LHS, Overlap::OUTER, false, 3));
- TEST_DO(verify_optimized("mixed+z3", Primary::LHS, Overlap::INNER, false, 5));
- TEST_DO(verify_optimized("y5z3+mixed", Primary::RHS, Overlap::FULL, false, 1));
- TEST_DO(verify_optimized("y5+mixed", Primary::RHS, Overlap::OUTER, false, 3));
- TEST_DO(verify_optimized("z3+mixed", Primary::RHS, Overlap::INNER, false, 5));
+ TEST_DO(verify_not_optimized("x3_1y5z3+x3_1y5z3$2"));
+ TEST_DO(verify_optimized("x3_1y5z3+y5z3", Primary::LHS, Overlap::FULL, 1));
+ TEST_DO(verify_optimized("x3_1y5z3+y5", Primary::LHS, Overlap::OUTER, 3));
+ TEST_DO(verify_optimized("x3_1y5z3+z3", Primary::LHS, Overlap::INNER, 5));
+ TEST_DO(verify_optimized("y5z3+x3_1y5z3", Primary::RHS, Overlap::FULL, 1));
+ TEST_DO(verify_optimized("y5+x3_1y5z3", Primary::RHS, Overlap::OUTER, 3));
+ TEST_DO(verify_optimized("z3+x3_1y5z3", Primary::RHS, Overlap::INNER, 5));
}
TEST("require that mixed tensors can be inplace") {
- TEST_DO(verify_optimized("@mixed+y5z3", Primary::LHS, Overlap::FULL, true, 1, 0));
- TEST_DO(verify_optimized("@mixed+y5", Primary::LHS, Overlap::OUTER, true, 3, 0));
- TEST_DO(verify_optimized("@mixed+z3", Primary::LHS, Overlap::INNER, true, 5, 0));
- TEST_DO(verify_optimized("y5z3+@mixed", Primary::RHS, Overlap::FULL, true, 1, 1));
- TEST_DO(verify_optimized("y5+@mixed", Primary::RHS, Overlap::OUTER, true, 3, 1));
- TEST_DO(verify_optimized("z3+@mixed", Primary::RHS, Overlap::INNER, true, 5, 1));
+ TEST_DO(verify_optimized("@x3_1y5z3+y5z3", Primary::LHS, Overlap::FULL, 1, true, false));
+ TEST_DO(verify_optimized("@x3_1y5z3+y5", Primary::LHS, Overlap::OUTER, 3, true, false));
+ TEST_DO(verify_optimized("@x3_1y5z3+z3", Primary::LHS, Overlap::INNER, 5, true, false));
+ TEST_DO(verify_optimized("y5z3+@x3_1y5z3", Primary::RHS, Overlap::FULL, 1, false, true));
+ TEST_DO(verify_optimized("y5+@x3_1y5z3", Primary::RHS, Overlap::OUTER, 3, false, true));
+ TEST_DO(verify_optimized("z3+@x3_1y5z3", Primary::RHS, Overlap::INNER, 5, false, true));
}
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/vespa/eval/instruction/mixed_simple_join_function.cpp b/eval/src/vespa/eval/instruction/mixed_simple_join_function.cpp
index 21c6f945609..69e01469bfa 100644
--- a/eval/src/vespa/eval/instruction/mixed_simple_join_function.cpp
+++ b/eval/src/vespa/eval/instruction/mixed_simple_join_function.cpp
@@ -176,14 +176,24 @@ MixedSimpleJoinFunction::MixedSimpleJoinFunction(const ValueType &result_type,
MixedSimpleJoinFunction::~MixedSimpleJoinFunction() = default;
+
+const TensorFunction &
+MixedSimpleJoinFunction::primary_child() const
+{
+ return (_primary == Primary::LHS) ? lhs() : rhs();
+}
+
bool
MixedSimpleJoinFunction::primary_is_mutable() const
{
- if (_primary == Primary::LHS) {
- return lhs().result_is_mutable();
- } else {
- return rhs().result_is_mutable();
- }
+ return primary_child().result_is_mutable();
+}
+
+bool
+MixedSimpleJoinFunction::inplace() const
+{
+ return primary_is_mutable() &&
+ (result_type() == primary_child().result_type());
}
size_t
diff --git a/eval/src/vespa/eval/instruction/mixed_simple_join_function.h b/eval/src/vespa/eval/instruction/mixed_simple_join_function.h
index 94e5f3c52b5..7658ff93689 100644
--- a/eval/src/vespa/eval/instruction/mixed_simple_join_function.h
+++ b/eval/src/vespa/eval/instruction/mixed_simple_join_function.h
@@ -28,6 +28,7 @@ public:
private:
Primary _primary;
Overlap _overlap;
+ const TensorFunction &primary_child() const;
public:
MixedSimpleJoinFunction(const ValueType &result_type,
const TensorFunction &lhs,
@@ -39,6 +40,7 @@ public:
Primary primary() const { return _primary; }
Overlap overlap() const { return _overlap; }
bool primary_is_mutable() const;
+ bool inplace() const;
size_t factor() const;
InterpretedFunction::Instruction compile_self(const ValueBuilderFactory &factory, Stash &stash) const override;
static const TensorFunction &optimize(const TensorFunction &expr, Stash &stash);