aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--eval/CMakeLists.txt1
-rw-r--r--eval/src/tests/instruction/join_with_number/CMakeLists.txt8
-rw-r--r--eval/src/tests/instruction/join_with_number/join_with_number_function_test.cpp136
-rw-r--r--eval/src/vespa/eval/eval/optimize_tensor_function.cpp4
-rw-r--r--eval/src/vespa/eval/instruction/CMakeLists.txt1
-rw-r--r--eval/src/vespa/eval/instruction/join_with_number_function.cpp116
-rw-r--r--eval/src/vespa/eval/instruction/join_with_number_function.h35
7 files changed, 299 insertions, 2 deletions
diff --git a/eval/CMakeLists.txt b/eval/CMakeLists.txt
index 19b75c7ff46..7d7d1c08f4e 100644
--- a/eval/CMakeLists.txt
+++ b/eval/CMakeLists.txt
@@ -45,6 +45,7 @@ vespa_define_module(
src/tests/instruction/generic_peek
src/tests/instruction/generic_reduce
src/tests/instruction/generic_rename
+ src/tests/instruction/join_with_number
src/tests/tensor/default_value_builder_factory
src/tests/tensor/dense_add_dimension_optimizer
src/tests/tensor/dense_dimension_combiner
diff --git a/eval/src/tests/instruction/join_with_number/CMakeLists.txt b/eval/src/tests/instruction/join_with_number/CMakeLists.txt
new file mode 100644
index 00000000000..9bb170eae17
--- /dev/null
+++ b/eval/src/tests/instruction/join_with_number/CMakeLists.txt
@@ -0,0 +1,8 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+vespa_add_executable(eval_join_with_number_function_test_app TEST
+ SOURCES
+ join_with_number_function_test.cpp
+ DEPENDS
+ vespaeval
+)
+vespa_add_test(NAME eval_join_with_number_function_test_app COMMAND eval_join_with_number_function_test_app)
diff --git a/eval/src/tests/instruction/join_with_number/join_with_number_function_test.cpp b/eval/src/tests/instruction/join_with_number/join_with_number_function_test.cpp
new file mode 100644
index 00000000000..a67fc3725ca
--- /dev/null
+++ b/eval/src/tests/instruction/join_with_number/join_with_number_function_test.cpp
@@ -0,0 +1,136 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/vespalib/testkit/test_kit.h>
+#include <vespa/eval/eval/fast_value.h>
+#include <vespa/eval/eval/tensor_function.h>
+#include <vespa/eval/eval/test/eval_fixture.h>
+#include <vespa/eval/eval/test/tensor_model.hpp>
+#include <vespa/eval/instruction/join_with_number_function.h>
+
+#include <vespa/vespalib/util/stringfmt.h>
+
+using namespace vespalib;
+using namespace vespalib::eval;
+using namespace vespalib::eval::test;
+using namespace vespalib::eval::tensor_function;
+
+using vespalib::make_string_short::fmt;
+
+using Primary = JoinWithNumberFunction::Primary;
+
+namespace vespalib::eval {
+
+std::ostream &operator<<(std::ostream &os, Primary primary)
+{
+ switch(primary) {
+ case Primary::LHS: return os << "LHS";
+ case Primary::RHS: return os << "RHS";
+ }
+ abort();
+}
+
+}
+
+const ValueBuilderFactory &prod_factory = FastValueBuilderFactory::get();
+
+EvalFixture::ParamRepo make_params() {
+ return EvalFixture::ParamRepo()
+ .add("a", spec(1.5))
+ .add("number", spec(2.5))
+ .add("sparse", spec({x({"a"})}, N()))
+ .add("dense", spec({y(5)}, N()))
+ .add("mixed", spec({x({"a"}),y(5)}, N()))
+ .add("mixed_float", spec(float_cells({x({"a"}),y(5)}), N()))
+ .add("mixed_inplace", spec({x({"a"}),y(5)}, N()), true)
+ .add_matrix("x", 3, "y", 5);
+}
+EvalFixture::ParamRepo param_repo = make_params();
+
+void verify_optimized(const vespalib::string &expr, Primary primary, bool inplace) {
+ EvalFixture slow_fixture(prod_factory, expr, param_repo, false);
+ EvalFixture fixture(prod_factory, expr, param_repo, true, true);
+ EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo));
+ EXPECT_EQUAL(fixture.result(), slow_fixture.result());
+ auto info = fixture.find_all<JoinWithNumberFunction>();
+ ASSERT_EQUAL(info.size(), 1u);
+ EXPECT_TRUE(info[0]->result_is_mutable());
+ EXPECT_EQUAL(info[0]->primary(), primary);
+ EXPECT_EQUAL(info[0]->inplace(), inplace);
+ int p_inplace = inplace ? ((primary == Primary::LHS) ? 0 : 1) : -1;
+ EXPECT_TRUE((p_inplace == -1) || (fixture.num_params() > size_t(p_inplace)));
+ for (size_t i = 0; i < fixture.num_params(); ++i) {
+ if (i == size_t(p_inplace)) {
+ EXPECT_EQUAL(fixture.get_param(i), fixture.result());
+ } else {
+ EXPECT_NOT_EQUAL(fixture.get_param(i), fixture.result());
+ }
+ }
+}
+
+void verify_not_optimized(const vespalib::string &expr) {
+ EvalFixture slow_fixture(prod_factory, expr, param_repo, false);
+ EvalFixture fixture(prod_factory, expr, param_repo, true);
+ EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo));
+ EXPECT_EQUAL(fixture.result(), slow_fixture.result());
+ auto info = fixture.find_all<JoinWithNumberFunction>();
+ EXPECT_TRUE(info.empty());
+}
+
+TEST("require that dense number join can be optimized") {
+ TEST_DO(verify_optimized("x3y5+a", Primary::LHS, false));
+ TEST_DO(verify_optimized("a+x3y5", Primary::RHS, false));
+ TEST_DO(verify_optimized("x3y5f*a", Primary::LHS, false));
+ TEST_DO(verify_optimized("a*x3y5f", Primary::RHS, false));
+}
+
+TEST("require that dense number join can be inplace") {
+ TEST_DO(verify_optimized("@x3y5*a", Primary::LHS, true));
+ TEST_DO(verify_optimized("a*@x3y5", Primary::RHS, true));
+ TEST_DO(verify_optimized("@x3y5f+a", Primary::LHS, true));
+ TEST_DO(verify_optimized("a+@x3y5f", Primary::RHS, true));
+}
+
+TEST("require that asymmetric operations work") {
+ TEST_DO(verify_optimized("x3y5/a", Primary::LHS, false));
+ TEST_DO(verify_optimized("a/x3y5", Primary::RHS, false));
+ TEST_DO(verify_optimized("x3y5f-a", Primary::LHS, false));
+ TEST_DO(verify_optimized("a-x3y5f", Primary::RHS, false));
+}
+
+TEST("require that mixed number join can be optimized") {
+ TEST_DO(verify_optimized("mixed+a", Primary::LHS, false));
+ TEST_DO(verify_optimized("a+mixed", Primary::RHS, false));
+ TEST_DO(verify_optimized("mixed<a", Primary::LHS, false));
+ TEST_DO(verify_optimized("a<mixed", Primary::RHS, false));
+ TEST_DO(verify_optimized("mixed_float+a", Primary::LHS, false));
+ TEST_DO(verify_optimized("a+mixed_float", Primary::RHS, false));
+ TEST_DO(verify_optimized("mixed_float<a", Primary::LHS, false));
+ TEST_DO(verify_optimized("a<mixed_float", Primary::RHS, false));
+}
+
+TEST("require that mixed number join can be inplace") {
+ TEST_DO(verify_optimized("mixed_inplace+a", Primary::LHS, true));
+ TEST_DO(verify_optimized("a+mixed_inplace", Primary::RHS, true));
+ TEST_DO(verify_optimized("mixed_inplace<a", Primary::LHS, true));
+ TEST_DO(verify_optimized("a<mixed_inplace", Primary::RHS, true));
+}
+
+TEST("require that all appropriate cases are optimized, others not") {
+ int optimized = 0;
+ for (vespalib::string lhs: {"number", "dense", "sparse", "mixed"}) {
+ for (vespalib::string rhs: {"number", "dense", "sparse", "mixed"}) {
+ auto expr = fmt("%s+%s", lhs.c_str(), rhs.c_str());
+ TEST_STATE(expr.c_str());
+ if ((lhs == "number") != (rhs == "number")) {
+ auto which = (rhs == "number") ? Primary::LHS : Primary::RHS;
+ verify_optimized(expr, which, false);
+ ++optimized;
+ } else {
+ verify_not_optimized(expr);
+ }
+ }
+ }
+ EXPECT_EQUAL(optimized, 6);
+}
+
+TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/vespa/eval/eval/optimize_tensor_function.cpp b/eval/src/vespa/eval/eval/optimize_tensor_function.cpp
index 83f806178e8..75ae3307599 100644
--- a/eval/src/vespa/eval/eval/optimize_tensor_function.cpp
+++ b/eval/src/vespa/eval/eval/optimize_tensor_function.cpp
@@ -17,7 +17,7 @@
#include <vespa/eval/tensor/dense/dense_lambda_function.h>
#include <vespa/eval/tensor/dense/dense_simple_expand_function.h>
#include <vespa/eval/tensor/dense/dense_simple_join_function.h>
-#include <vespa/eval/tensor/dense/dense_number_join_function.h>
+#include <vespa/eval/instruction/join_with_number_function.h>
#include <vespa/eval/tensor/dense/dense_pow_as_map_optimizer.h>
#include <vespa/eval/tensor/dense/dense_simple_map_function.h>
#include <vespa/eval/tensor/dense/vector_from_doubles_function.h>
@@ -73,7 +73,7 @@ const TensorFunction &optimize_for_factory(const ValueBuilderFactory &factory, c
child.set(DensePowAsMapOptimizer::optimize(child.get(), stash));
child.set(DenseSimpleMapFunction::optimize(child.get(), stash));
child.set(DenseSimpleJoinFunction::optimize(child.get(), stash));
- child.set(DenseNumberJoinFunction::optimize(child.get(), stash));
+ child.set(JoinWithNumberFunction::optimize(child.get(), stash));
child.set(DenseSingleReduceFunction::optimize(child.get(), stash));
nodes.pop_back();
}
diff --git a/eval/src/vespa/eval/instruction/CMakeLists.txt b/eval/src/vespa/eval/instruction/CMakeLists.txt
index 926a69bd291..7211ebbfb49 100644
--- a/eval/src/vespa/eval/instruction/CMakeLists.txt
+++ b/eval/src/vespa/eval/instruction/CMakeLists.txt
@@ -11,4 +11,5 @@ vespa_add_library(eval_instruction OBJECT
generic_peek.cpp
generic_reduce.cpp
generic_rename.cpp
+ join_with_number_function.cpp
)
diff --git a/eval/src/vespa/eval/instruction/join_with_number_function.cpp b/eval/src/vespa/eval/instruction/join_with_number_function.cpp
new file mode 100644
index 00000000000..dd3512a5e74
--- /dev/null
+++ b/eval/src/vespa/eval/instruction/join_with_number_function.cpp
@@ -0,0 +1,116 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "join_with_number_function.h"
+#include <vespa/vespalib/objects/objectvisitor.h>
+#include <vespa/eval/eval/value.h>
+#include <vespa/eval/eval/operation.h>
+#include <vespa/eval/eval/inline_operation.h>
+#include <vespa/vespalib/util/typify.h>
+
+using namespace vespalib::eval::tensor_function;
+using namespace vespalib::eval::operation;
+
+namespace vespalib::eval {
+
+using Instruction = InterpretedFunction::Instruction;
+using State = InterpretedFunction::State;
+
+namespace {
+
+template <typename CT, bool inplace>
+ArrayRef<CT> make_dst_cells(ConstArrayRef<CT> src_cells, Stash &stash) {
+ if (inplace) {
+ return unconstify(src_cells);
+ } else {
+ return stash.create_uninitialized_array<CT>(src_cells.size());
+ }
+}
+
+template <typename CT, typename Fun, bool inplace, bool swap>
+void my_number_join_op(State &state, uint64_t param) {
+ using OP = typename std::conditional<swap,SwapArgs2<Fun>,Fun>::type;
+ OP my_op((join_fun_t)param);
+ const Value &tensor = state.peek(swap ? 0 : 1);
+ CT number = state.peek(swap ? 1 : 0).as_double();
+ auto src_cells = tensor.cells().typify<CT>();
+ auto dst_cells = make_dst_cells<CT, inplace>(src_cells, state.stash);
+ apply_op2_vec_num(dst_cells.begin(), src_cells.begin(), number, dst_cells.size(), my_op);
+ if (inplace) {
+ state.pop_pop_push(tensor);
+ } else {
+ state.pop_pop_push(state.stash.create<ValueView>(tensor.type(), tensor.index(), TypedCells(dst_cells)));
+ }
+}
+
+struct SelectJoinWithNumberOp {
+ template<typename CT, typename Fun,
+ typename InputIsMutable, typename NumberWasLeft>
+ static auto invoke() {
+ return my_number_join_op<CT, Fun, InputIsMutable::value, NumberWasLeft::value>;
+ }
+};
+
+} // namespace <unnamed>
+
+JoinWithNumberFunction::JoinWithNumberFunction(const Join &original, bool tensor_was_right)
+ : tensor_function::Op2(original.result_type(), original.lhs(), original.rhs()),
+ _primary(tensor_was_right ? Primary::RHS : Primary::LHS),
+ _function(original.function())
+{
+}
+
+JoinWithNumberFunction::~JoinWithNumberFunction() = default;
+
+bool
+JoinWithNumberFunction::inplace() const {
+ if (_primary == Primary::LHS) {
+ return lhs().result_is_mutable();
+ } else {
+ return rhs().result_is_mutable();
+ }
+}
+
+using MyTypify = TypifyValue<TypifyCellType,vespalib::TypifyBool,operation::TypifyOp2>;
+
+InterpretedFunction::Instruction
+JoinWithNumberFunction::compile_self(EngineOrFactory, Stash &) const
+{
+ auto op = typify_invoke<4,MyTypify,SelectJoinWithNumberOp>(result_type().cell_type(),
+ _function,
+ inplace(),
+ (_primary == Primary::RHS));
+ return Instruction(op, (uint64_t)(_function));
+}
+
+void
+JoinWithNumberFunction::visit_self(vespalib::ObjectVisitor &visitor) const
+{
+ Super::visit_self(visitor);
+ visitor.visitBool("tensor_was_right", (_primary == Primary::RHS));
+ visitor.visitBool("is_inplace", inplace());
+}
+
+const TensorFunction &
+JoinWithNumberFunction::optimize(const TensorFunction &expr, Stash &stash)
+{
+ if (! expr.result_type().is_scalar()) {
+ if (const auto *join = as<Join>(expr)) {
+ const ValueType &result_type = join->result_type();
+ const TensorFunction &lhs = join->lhs();
+ const TensorFunction &rhs = join->rhs();
+ if (lhs.result_type().is_double() &&
+ (result_type == rhs.result_type()))
+ {
+ return stash.create<JoinWithNumberFunction>(*join, true);
+ }
+ if (rhs.result_type().is_double() &&
+ (result_type == lhs.result_type()))
+ {
+ return stash.create<JoinWithNumberFunction>(*join, false);
+ }
+ }
+ }
+ return expr;
+}
+
+} // namespace
diff --git a/eval/src/vespa/eval/instruction/join_with_number_function.h b/eval/src/vespa/eval/instruction/join_with_number_function.h
new file mode 100644
index 00000000000..6e3f9aa4106
--- /dev/null
+++ b/eval/src/vespa/eval/instruction/join_with_number_function.h
@@ -0,0 +1,35 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include <vespa/eval/eval/value_type.h>
+#include <vespa/eval/eval/tensor_function.h>
+
+namespace vespalib::eval {
+
+/**
+ * Tensor function for joining a general tensor with a number
+ */
+class JoinWithNumberFunction : public tensor_function::Op2
+{
+public:
+ enum class Primary : uint8_t { LHS, RHS };
+private:
+ using Super = tensor_function::Op2;
+ Primary _primary;
+ tensor_function::join_fun_t _function;
+public:
+
+ JoinWithNumberFunction(const vespalib::eval::tensor_function::Join &original_join, bool number_on_left);
+ ~JoinWithNumberFunction();
+ Primary primary() const { return _primary; }
+ bool inplace() const;
+ bool result_is_mutable() const override { return true; }
+
+ InterpretedFunction::Instruction compile_self(EngineOrFactory engine, Stash &stash) const override;
+ void visit_self(vespalib::ObjectVisitor &visitor) const override;
+ static const TensorFunction &optimize(const TensorFunction &expr, Stash &stash);
+};
+
+} // namespace vespalib::tensor
+