aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/tests/instruction/join_with_number/join_with_number_function_test.cpp
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2020-11-10 16:15:14 +0000
committerArne Juul <arnej@verizonmedia.com>2020-11-11 16:03:17 +0000
commit3968f683fd4a6883d896cda698a34729e7338148 (patch)
treed27201acd92f4665f3bc0684cfc55f0b0b113dc1 /eval/src/tests/instruction/join_with_number/join_with_number_function_test.cpp
parent923e66e078e7b9cc0dbc6399432c9e3172f62658 (diff)
optimize join with number, with unit test
Diffstat (limited to 'eval/src/tests/instruction/join_with_number/join_with_number_function_test.cpp')
-rw-r--r--eval/src/tests/instruction/join_with_number/join_with_number_function_test.cpp136
1 files changed, 136 insertions, 0 deletions
diff --git a/eval/src/tests/instruction/join_with_number/join_with_number_function_test.cpp b/eval/src/tests/instruction/join_with_number/join_with_number_function_test.cpp
new file mode 100644
index 00000000000..a67fc3725ca
--- /dev/null
+++ b/eval/src/tests/instruction/join_with_number/join_with_number_function_test.cpp
@@ -0,0 +1,136 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/vespalib/testkit/test_kit.h>
+#include <vespa/eval/eval/fast_value.h>
+#include <vespa/eval/eval/tensor_function.h>
+#include <vespa/eval/eval/test/eval_fixture.h>
+#include <vespa/eval/eval/test/tensor_model.hpp>
+#include <vespa/eval/instruction/join_with_number_function.h>
+
+#include <vespa/vespalib/util/stringfmt.h>
+
+using namespace vespalib;
+using namespace vespalib::eval;
+using namespace vespalib::eval::test;
+using namespace vespalib::eval::tensor_function;
+
+using vespalib::make_string_short::fmt;
+
+using Primary = JoinWithNumberFunction::Primary;
+
+namespace vespalib::eval {
+
+std::ostream &operator<<(std::ostream &os, Primary primary)
+{
+ switch(primary) {
+ case Primary::LHS: return os << "LHS";
+ case Primary::RHS: return os << "RHS";
+ }
+ abort();
+}
+
+}
+
+const ValueBuilderFactory &prod_factory = FastValueBuilderFactory::get();
+
+EvalFixture::ParamRepo make_params() {
+ return EvalFixture::ParamRepo()
+ .add("a", spec(1.5))
+ .add("number", spec(2.5))
+ .add("sparse", spec({x({"a"})}, N()))
+ .add("dense", spec({y(5)}, N()))
+ .add("mixed", spec({x({"a"}),y(5)}, N()))
+ .add("mixed_float", spec(float_cells({x({"a"}),y(5)}), N()))
+ .add("mixed_inplace", spec({x({"a"}),y(5)}, N()), true)
+ .add_matrix("x", 3, "y", 5);
+}
+EvalFixture::ParamRepo param_repo = make_params();
+
+void verify_optimized(const vespalib::string &expr, Primary primary, bool inplace) {
+ EvalFixture slow_fixture(prod_factory, expr, param_repo, false);
+ EvalFixture fixture(prod_factory, expr, param_repo, true, true);
+ EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo));
+ EXPECT_EQUAL(fixture.result(), slow_fixture.result());
+ auto info = fixture.find_all<JoinWithNumberFunction>();
+ ASSERT_EQUAL(info.size(), 1u);
+ EXPECT_TRUE(info[0]->result_is_mutable());
+ EXPECT_EQUAL(info[0]->primary(), primary);
+ EXPECT_EQUAL(info[0]->inplace(), inplace);
+ int p_inplace = inplace ? ((primary == Primary::LHS) ? 0 : 1) : -1;
+ EXPECT_TRUE((p_inplace == -1) || (fixture.num_params() > size_t(p_inplace)));
+ for (size_t i = 0; i < fixture.num_params(); ++i) {
+ if (i == size_t(p_inplace)) {
+ EXPECT_EQUAL(fixture.get_param(i), fixture.result());
+ } else {
+ EXPECT_NOT_EQUAL(fixture.get_param(i), fixture.result());
+ }
+ }
+}
+
+void verify_not_optimized(const vespalib::string &expr) {
+ EvalFixture slow_fixture(prod_factory, expr, param_repo, false);
+ EvalFixture fixture(prod_factory, expr, param_repo, true);
+ EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo));
+ EXPECT_EQUAL(fixture.result(), slow_fixture.result());
+ auto info = fixture.find_all<JoinWithNumberFunction>();
+ EXPECT_TRUE(info.empty());
+}
+
+TEST("require that dense number join can be optimized") {
+ TEST_DO(verify_optimized("x3y5+a", Primary::LHS, false));
+ TEST_DO(verify_optimized("a+x3y5", Primary::RHS, false));
+ TEST_DO(verify_optimized("x3y5f*a", Primary::LHS, false));
+ TEST_DO(verify_optimized("a*x3y5f", Primary::RHS, false));
+}
+
+TEST("require that dense number join can be inplace") {
+ TEST_DO(verify_optimized("@x3y5*a", Primary::LHS, true));
+ TEST_DO(verify_optimized("a*@x3y5", Primary::RHS, true));
+ TEST_DO(verify_optimized("@x3y5f+a", Primary::LHS, true));
+ TEST_DO(verify_optimized("a+@x3y5f", Primary::RHS, true));
+}
+
+TEST("require that asymmetric operations work") {
+ TEST_DO(verify_optimized("x3y5/a", Primary::LHS, false));
+ TEST_DO(verify_optimized("a/x3y5", Primary::RHS, false));
+ TEST_DO(verify_optimized("x3y5f-a", Primary::LHS, false));
+ TEST_DO(verify_optimized("a-x3y5f", Primary::RHS, false));
+}
+
+TEST("require that mixed number join can be optimized") {
+ TEST_DO(verify_optimized("mixed+a", Primary::LHS, false));
+ TEST_DO(verify_optimized("a+mixed", Primary::RHS, false));
+ TEST_DO(verify_optimized("mixed<a", Primary::LHS, false));
+ TEST_DO(verify_optimized("a<mixed", Primary::RHS, false));
+ TEST_DO(verify_optimized("mixed_float+a", Primary::LHS, false));
+ TEST_DO(verify_optimized("a+mixed_float", Primary::RHS, false));
+ TEST_DO(verify_optimized("mixed_float<a", Primary::LHS, false));
+ TEST_DO(verify_optimized("a<mixed_float", Primary::RHS, false));
+}
+
+TEST("require that mixed number join can be inplace") {
+ TEST_DO(verify_optimized("mixed_inplace+a", Primary::LHS, true));
+ TEST_DO(verify_optimized("a+mixed_inplace", Primary::RHS, true));
+ TEST_DO(verify_optimized("mixed_inplace<a", Primary::LHS, true));
+ TEST_DO(verify_optimized("a<mixed_inplace", Primary::RHS, true));
+}
+
+TEST("require that all appropriate cases are optimized, others not") {
+ int optimized = 0;
+ for (vespalib::string lhs: {"number", "dense", "sparse", "mixed"}) {
+ for (vespalib::string rhs: {"number", "dense", "sparse", "mixed"}) {
+ auto expr = fmt("%s+%s", lhs.c_str(), rhs.c_str());
+ TEST_STATE(expr.c_str());
+ if ((lhs == "number") != (rhs == "number")) {
+ auto which = (rhs == "number") ? Primary::LHS : Primary::RHS;
+ verify_optimized(expr, which, false);
+ ++optimized;
+ } else {
+ verify_not_optimized(expr);
+ }
+ }
+ }
+ EXPECT_EQUAL(optimized, 6);
+}
+
+TEST_MAIN() { TEST_RUN_ALL(); }