diff options
author | Arne Juul <arnej@verizonmedia.com> | 2021-03-19 06:09:32 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2021-03-19 11:26:10 +0000 |
commit | e349ad115e423bde38ae1894d9fcb290655b20c7 (patch) | |
tree | a6dc3b861eea4b130f26c771028c332ee6630eb4 /eval/src | |
parent | ef422d3dcdfc9cac070446e9c855e3eb242fba43 (diff) |
rewrite mixed_simple_join_function_test using EvalFixture::verify
Diffstat (limited to 'eval/src')
3 files changed, 134 insertions, 128 deletions
diff --git a/eval/src/tests/instruction/mixed_simple_join_function/mixed_simple_join_function_test.cpp b/eval/src/tests/instruction/mixed_simple_join_function/mixed_simple_join_function_test.cpp index 6220682d716..b5f2402a77c 100644 --- a/eval/src/tests/instruction/mixed_simple_join_function/mixed_simple_join_function_test.cpp +++ b/eval/src/tests/instruction/mixed_simple_join_function/mixed_simple_join_function_test.cpp @@ -41,78 +41,95 @@ std::ostream &operator<<(std::ostream &os, Overlap overlap) } -const ValueBuilderFactory &prod_factory = FastValueBuilderFactory::get(); - -EvalFixture::ParamRepo make_params() { - return EvalFixture::ParamRepo() - .add("a", GenSpec(1.5)) - .add("b", GenSpec(2.5)) - .add("sparse", GenSpec().map("x", {"a", "b", "c"})) - .add("mixed", GenSpec().map("x", {"a", "b", "c"}).idx("y", 5).idx("z", 3)) - .add("empty_mixed", GenSpec().map("x", {}).idx("y", 5).idx("z", 3)) - .add_mutable("@mixed", GenSpec().map("x", {"a", "b", "c"}).idx("y", 5).idx("z", 3)) - .add_variants("a1b1c1", GenSpec().idx("a", 1).idx("b", 1).idx("c", 1)) - .add_variants("x1y1z1", GenSpec().idx("x", 1).idx("y", 1).idx("z", 1)) - .add_variants("x3y5z3", GenSpec().idx("x", 3).idx("y", 5).idx("z", 3)) - .add_variants("z3", GenSpec().idx("z", 3)) - .add_variants("c5d1", GenSpec().idx("c", 5).idx("d", 1)) - .add_variants("b1c5", GenSpec().idx("b", 1).idx("c", 5)) - .add_variants("x3y5", GenSpec().idx("x", 3).idx("y", 5).seq([](size_t idx) noexcept { return double((idx * 2) + 3); })) - .add_variants("x3y5$2", GenSpec().idx("x", 3).idx("y", 5).seq([](size_t idx) noexcept { return double((idx * 3) + 2); })) - .add_variants("y5", GenSpec().idx("y", 5).seq([](size_t idx) noexcept { return double((idx * 2) + 3); })) - .add_variants("y5$2", GenSpec().idx("y", 5).seq([](size_t idx) noexcept { return double((idx * 3) + 2); })) - .add_variants("y5z3", GenSpec().idx("y", 5).idx("z", 3).seq([](size_t idx) noexcept { return double((idx * 2) + 3); })) - .add_variants("y5z3$2", GenSpec().idx("y", 5).idx("z", 3).seq([](size_t idx) noexcept { return double((idx * 3) + 2); })); -} -EvalFixture::ParamRepo param_repo = make_params(); - -void verify_optimized(const vespalib::string &expr, Primary primary, Overlap overlap, bool pri_mut, size_t factor, int p_inplace = -1) { - EvalFixture slow_fixture(prod_factory, expr, param_repo, false); - EvalFixture fixture(prod_factory, expr, param_repo, true, true); - EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo)); - EXPECT_EQUAL(fixture.result(), slow_fixture.result()); - auto info = fixture.find_all<MixedSimpleJoinFunction>(); - ASSERT_EQUAL(info.size(), 1u); - EXPECT_TRUE(info[0]->result_is_mutable()); - EXPECT_EQUAL(info[0]->primary(), primary); - EXPECT_EQUAL(info[0]->overlap(), overlap); - EXPECT_EQUAL(info[0]->primary_is_mutable(), pri_mut); - EXPECT_EQUAL(info[0]->factor(), factor); - EXPECT_TRUE((p_inplace == -1) || (fixture.num_params() > size_t(p_inplace))); - for (size_t i = 0; i < fixture.num_params(); ++i) { - if (i == size_t(p_inplace)) { - EXPECT_EQUAL(fixture.get_param(i), fixture.result()); - } else { - if (!fixture.result().cells().empty()) { - EXPECT_NOT_EQUAL(fixture.get_param(i), fixture.result()); +struct FunInfo { + using LookFor = MixedSimpleJoinFunction; + Overlap overlap; + size_t factor; + std::optional<Primary> primary; + bool l_mut; + bool r_mut; + bool require_inplace; + void verify(const EvalFixture &fixture, const LookFor &fun) const { + EXPECT_TRUE(fun.result_is_mutable()); + EXPECT_EQUAL(fun.overlap(), overlap); + EXPECT_EQUAL(fun.factor(), factor); + if (primary.has_value()) { + EXPECT_EQUAL(fun.primary(), primary.value()); + } + if (fun.primary_is_mutable()) { + if (fun.primary() == Primary::LHS) { + EXPECT_TRUE(l_mut); + } + if (fun.primary() == Primary::RHS) { + EXPECT_TRUE(r_mut); } } + if (require_inplace) { + EXPECT_TRUE(fun.inplace()); + } + if (fun.inplace()) { + EXPECT_TRUE(fun.primary_is_mutable()); + size_t idx = (fun.primary() == Primary::LHS) ? 0 : 1; + EXPECT_EQUAL(fixture.result_value().cells().data, + fixture.param_value(idx).cells().data); + } } +}; + +void verify_simple(const vespalib::string &expr, Primary primary, Overlap overlap, size_t factor, + bool l_mut, bool r_mut, bool inplace) +{ + TEST_STATE(expr.c_str()); + CellTypeSpace just_double({CellType::DOUBLE}, 2); + FunInfo details{overlap, factor, primary, l_mut, r_mut, inplace}; + EvalFixture::verify<FunInfo>(expr, {details}, just_double); +} + +void verify_optimized(const vespalib::string &expr, Primary primary, Overlap overlap, size_t factor, + bool l_mut = false, bool r_mut = false) +{ + TEST_STATE(expr.c_str()); + CellTypeSpace all_types(CellTypeUtils::list_types(), 2); + FunInfo details{overlap, factor, primary, l_mut, r_mut, false}; + EvalFixture::verify<FunInfo>(expr, {details}, all_types); +} + +void verify_optimized(const vespalib::string &expr, Overlap overlap, size_t factor, + bool l_mut = false, bool r_mut = false) +{ + TEST_STATE(expr.c_str()); + CellTypeSpace all_types(CellTypeUtils::list_types(), 2); + FunInfo details{overlap, factor, std::nullopt, l_mut, r_mut, false}; + EvalFixture::verify<FunInfo>(expr, {details}, all_types); } void verify_not_optimized(const vespalib::string &expr) { - EvalFixture slow_fixture(prod_factory, expr, param_repo, false); - EvalFixture fixture(prod_factory, expr, param_repo, true); - EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo)); - EXPECT_EQUAL(fixture.result(), slow_fixture.result()); - auto info = fixture.find_all<MixedSimpleJoinFunction>(); - EXPECT_TRUE(info.empty()); + TEST_STATE(expr.c_str()); + CellTypeSpace just_double({CellType::DOUBLE}, 2); + EvalFixture::verify<FunInfo>(expr, {}, just_double); } TEST("require that basic join is optimized") { - TEST_DO(verify_optimized("y5+y5$2", Primary::RHS, Overlap::FULL, false, 1)); + TEST_DO(verify_optimized("y5+y5$2", Primary::RHS, Overlap::FULL, 1)); +} + +TEST("require that inplace is preferred") { + TEST_DO(verify_simple("y5+y5$2", Primary::RHS, Overlap::FULL, 1, false, false, false)); + TEST_DO(verify_simple("y5+@y5$2", Primary::RHS, Overlap::FULL, 1, false, true, true)); + TEST_DO(verify_simple("@y5+@y5$2", Primary::RHS, Overlap::FULL, 1, true, true, true)); + TEST_DO(verify_simple("@y5+y5$2", Primary::LHS, Overlap::FULL, 1, true, false, true)); } TEST("require that unit join is optimized") { - TEST_DO(verify_optimized("a1b1c1+x1y1z1", Primary::RHS, Overlap::FULL, false, 1)); + TEST_DO(verify_optimized("a1b1c1+x1y1z1", Primary::RHS, Overlap::FULL, 1)); } TEST("require that trivial dimensions do not affect overlap calculation") { - TEST_DO(verify_optimized("c5d1+b1c5", Primary::RHS, Overlap::FULL, false, 1)); + TEST_DO(verify_optimized("c5d1+b1c5", Primary::RHS, Overlap::FULL, 1)); } TEST("require that outer nesting is preferred to inner nesting") { - TEST_DO(verify_optimized("a1b1c1+y5", Primary::RHS, Overlap::OUTER, false, 5)); + TEST_DO(verify_optimized("a1b1c1+y5", Primary::RHS, Overlap::OUTER, 5)); } TEST("require that non-subset join is not optimized") { @@ -144,7 +161,7 @@ struct LhsRhs { } }; -vespalib::string adjust_param(const vespalib::string &str, bool float_cells, bool mut_cells, bool is_rhs) { +vespalib::string adjust_param(const vespalib::string &str, bool mut_cells, bool is_rhs) { vespalib::string result = str; if (mut_cells) { result = "@" + result; @@ -152,49 +169,26 @@ vespalib::string adjust_param(const vespalib::string &str, bool float_cells, boo if (is_rhs) { result += "$2"; } - if (float_cells) { - result += "_f"; - } return result; } TEST("require that various parameter combinations work") { - for (bool left_float: {false, true}) { - for (bool right_float: {false, true}) { - bool float_result = (left_float && right_float); - for (bool left_mut: {false, true}) { - for (bool right_mut: {false, true}) { - for (const char *op_pattern: {"%s+%s", "%s-%s", "%s*%s"}) { - for (const LhsRhs ¶ms: - { LhsRhs("y5", "y5", 5, 5, Overlap::FULL), - LhsRhs("y5", "x3y5", 5, 15, Overlap::INNER), - LhsRhs("y5", "y5z3", 5, 15, Overlap::OUTER), - LhsRhs("x3y5", "y5", 15, 5, Overlap::INNER), - LhsRhs("y5z3", "y5", 15, 5, Overlap::OUTER)}) - { - vespalib::string left = adjust_param(params.lhs, left_float, left_mut, false); - vespalib::string right = adjust_param(params.rhs, right_float, right_mut, true); - vespalib::string expr = fmt(op_pattern, left.c_str(), right.c_str()); - TEST_STATE(expr.c_str()); - Primary primary = Primary::RHS; - if (params.overlap == Overlap::FULL) { - bool w_lhs = ((left_float == float_result) && left_mut); - bool w_rhs = ((right_float == float_result) && right_mut); - if (w_lhs && !w_rhs) { - primary = Primary::LHS; - } - } else if (params.lhs_size > params.rhs_size) { - primary = Primary::LHS; - } - bool pri_mut = (primary == Primary::LHS) ? left_mut : right_mut; - bool pri_float = (primary == Primary::LHS) ? left_float : right_float; - int p_inplace = -1; - if (pri_mut && (pri_float == float_result)) { - p_inplace = (primary == Primary::LHS) ? 0 : 1; - } - verify_optimized(expr, primary, params.overlap, pri_mut, params.factor, p_inplace); - } - } + for (bool left_mut: {false, true}) { + for (bool right_mut: {false, true}) { + for (const char *op_pattern: {"%s+%s", "%s-%s", "%s*%s"}) { + for (const LhsRhs ¶ms: + { LhsRhs("y5", "y5", 5, 5, Overlap::FULL), + LhsRhs("y5", "x3y5", 5, 15, Overlap::INNER), + LhsRhs("y5", "y5z3", 5, 15, Overlap::OUTER), + LhsRhs("x3y5", "y5", 15, 5, Overlap::INNER), + LhsRhs("y5z3", "y5", 15, 5, Overlap::OUTER)}) + { + vespalib::string left = adjust_param(params.lhs, left_mut, false); + vespalib::string right = adjust_param(params.rhs, right_mut, true); + vespalib::string expr = fmt(op_pattern, left.c_str(), right.c_str()); + TEST_STATE(expr.c_str()); + verify_optimized(expr, params.overlap, params.factor, + left_mut, right_mut); } } } @@ -202,50 +196,50 @@ TEST("require that various parameter combinations work") { } TEST("require that scalar values are not optimized") { - TEST_DO(verify_not_optimized("a+b")); - TEST_DO(verify_not_optimized("a+y5")); - TEST_DO(verify_not_optimized("y5+b")); - TEST_DO(verify_not_optimized("a+sparse")); - TEST_DO(verify_not_optimized("sparse+a")); - TEST_DO(verify_not_optimized("a+mixed")); - TEST_DO(verify_not_optimized("mixed+a")); + TEST_DO(verify_not_optimized("reduce(v3,sum)+reduce(v4,sum)")); + TEST_DO(verify_not_optimized("reduce(v3,sum)+y5")); + TEST_DO(verify_not_optimized("y5+reduce(v3,sum)")); + TEST_DO(verify_not_optimized("reduce(v3,sum)+x3_1")); + TEST_DO(verify_not_optimized("x3_1+reduce(v3,sum)")); + TEST_DO(verify_not_optimized("reduce(v3,sum)+x3_1y5z3")); + TEST_DO(verify_not_optimized("x3_1y5z3+reduce(v3,sum)")); } TEST("require that sparse tensors are mostly not optimized") { - TEST_DO(verify_not_optimized("sparse+sparse")); - TEST_DO(verify_not_optimized("sparse+y5")); - TEST_DO(verify_not_optimized("y5+sparse")); - TEST_DO(verify_not_optimized("sparse+mixed")); - TEST_DO(verify_not_optimized("mixed+sparse")); + TEST_DO(verify_not_optimized("x3_1+x3_1$2")); + TEST_DO(verify_not_optimized("x3_1+y5")); + TEST_DO(verify_not_optimized("y5+x3_1")); + TEST_DO(verify_not_optimized("x3_1+x3_1y5z3")); + TEST_DO(verify_not_optimized("x3_1y5z3+x3_1")); } TEST("require that sparse tensor joined with trivial dense tensor is optimized") { - TEST_DO(verify_optimized("sparse+a1b1c1", Primary::LHS, Overlap::FULL, false, 1)); - TEST_DO(verify_optimized("a1b1c1+sparse", Primary::RHS, Overlap::FULL, false, 1)); + TEST_DO(verify_optimized("x3_1+a1b1c1", Primary::LHS, Overlap::FULL, 1)); + TEST_DO(verify_optimized("a1b1c1+x3_1", Primary::RHS, Overlap::FULL, 1)); } TEST("require that primary tensor can be empty") { - TEST_DO(verify_optimized("empty_mixed+y5z3", Primary::LHS, Overlap::FULL, false, 1)); - TEST_DO(verify_optimized("y5z3+empty_mixed", Primary::RHS, Overlap::FULL, false, 1)); + TEST_DO(verify_optimized("x0_1y5z3+y5z3", Primary::LHS, Overlap::FULL, 1)); + TEST_DO(verify_optimized("y5z3+x0_1y5z3", Primary::RHS, Overlap::FULL, 1)); } TEST("require that mixed tensors can be optimized") { - TEST_DO(verify_not_optimized("mixed+mixed")); - TEST_DO(verify_optimized("mixed+y5z3", Primary::LHS, Overlap::FULL, false, 1)); - TEST_DO(verify_optimized("mixed+y5", Primary::LHS, Overlap::OUTER, false, 3)); - TEST_DO(verify_optimized("mixed+z3", Primary::LHS, Overlap::INNER, false, 5)); - TEST_DO(verify_optimized("y5z3+mixed", Primary::RHS, Overlap::FULL, false, 1)); - TEST_DO(verify_optimized("y5+mixed", Primary::RHS, Overlap::OUTER, false, 3)); - TEST_DO(verify_optimized("z3+mixed", Primary::RHS, Overlap::INNER, false, 5)); + TEST_DO(verify_not_optimized("x3_1y5z3+x3_1y5z3$2")); + TEST_DO(verify_optimized("x3_1y5z3+y5z3", Primary::LHS, Overlap::FULL, 1)); + TEST_DO(verify_optimized("x3_1y5z3+y5", Primary::LHS, Overlap::OUTER, 3)); + TEST_DO(verify_optimized("x3_1y5z3+z3", Primary::LHS, Overlap::INNER, 5)); + TEST_DO(verify_optimized("y5z3+x3_1y5z3", Primary::RHS, Overlap::FULL, 1)); + TEST_DO(verify_optimized("y5+x3_1y5z3", Primary::RHS, Overlap::OUTER, 3)); + TEST_DO(verify_optimized("z3+x3_1y5z3", Primary::RHS, Overlap::INNER, 5)); } TEST("require that mixed tensors can be inplace") { - TEST_DO(verify_optimized("@mixed+y5z3", Primary::LHS, Overlap::FULL, true, 1, 0)); - TEST_DO(verify_optimized("@mixed+y5", Primary::LHS, Overlap::OUTER, true, 3, 0)); - TEST_DO(verify_optimized("@mixed+z3", Primary::LHS, Overlap::INNER, true, 5, 0)); - TEST_DO(verify_optimized("y5z3+@mixed", Primary::RHS, Overlap::FULL, true, 1, 1)); - TEST_DO(verify_optimized("y5+@mixed", Primary::RHS, Overlap::OUTER, true, 3, 1)); - TEST_DO(verify_optimized("z3+@mixed", Primary::RHS, Overlap::INNER, true, 5, 1)); + TEST_DO(verify_optimized("@x3_1y5z3+y5z3", Primary::LHS, Overlap::FULL, 1, true, false)); + TEST_DO(verify_optimized("@x3_1y5z3+y5", Primary::LHS, Overlap::OUTER, 3, true, false)); + TEST_DO(verify_optimized("@x3_1y5z3+z3", Primary::LHS, Overlap::INNER, 5, true, false)); + TEST_DO(verify_optimized("y5z3+@x3_1y5z3", Primary::RHS, Overlap::FULL, 1, false, true)); + TEST_DO(verify_optimized("y5+@x3_1y5z3", Primary::RHS, Overlap::OUTER, 3, false, true)); + TEST_DO(verify_optimized("z3+@x3_1y5z3", Primary::RHS, Overlap::INNER, 5, false, true)); } TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/eval/src/vespa/eval/instruction/mixed_simple_join_function.cpp b/eval/src/vespa/eval/instruction/mixed_simple_join_function.cpp index 21c6f945609..69e01469bfa 100644 --- a/eval/src/vespa/eval/instruction/mixed_simple_join_function.cpp +++ b/eval/src/vespa/eval/instruction/mixed_simple_join_function.cpp @@ -176,14 +176,24 @@ MixedSimpleJoinFunction::MixedSimpleJoinFunction(const ValueType &result_type, MixedSimpleJoinFunction::~MixedSimpleJoinFunction() = default; + +const TensorFunction & +MixedSimpleJoinFunction::primary_child() const +{ + return (_primary == Primary::LHS) ? lhs() : rhs(); +} + bool MixedSimpleJoinFunction::primary_is_mutable() const { - if (_primary == Primary::LHS) { - return lhs().result_is_mutable(); - } else { - return rhs().result_is_mutable(); - } + return primary_child().result_is_mutable(); +} + +bool +MixedSimpleJoinFunction::inplace() const +{ + return primary_is_mutable() && + (result_type() == primary_child().result_type()); } size_t diff --git a/eval/src/vespa/eval/instruction/mixed_simple_join_function.h b/eval/src/vespa/eval/instruction/mixed_simple_join_function.h index 94e5f3c52b5..7658ff93689 100644 --- a/eval/src/vespa/eval/instruction/mixed_simple_join_function.h +++ b/eval/src/vespa/eval/instruction/mixed_simple_join_function.h @@ -28,6 +28,7 @@ public: private: Primary _primary; Overlap _overlap; + const TensorFunction &primary_child() const; public: MixedSimpleJoinFunction(const ValueType &result_type, const TensorFunction &lhs, @@ -39,6 +40,7 @@ public: Primary primary() const { return _primary; } Overlap overlap() const { return _overlap; } bool primary_is_mutable() const; + bool inplace() const; size_t factor() const; InterpretedFunction::Instruction compile_self(const ValueBuilderFactory &factory, Stash &stash) const override; static const TensorFunction &optimize(const TensorFunction &expr, Stash &stash); |