summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorArne H Juul <arnej27959@users.noreply.github.com>2021-02-05 16:15:04 +0100
committerGitHub <noreply@github.com>2021-02-05 16:15:04 +0100
commit0e5959f756feb6f0883388215152772fa5d8181f (patch)
tree1fd9a52d1df8d362a38960aa8952e3ec40bd2593 /eval
parent4c6758eea33c9b28a1667730f12d6106a60cec67 (diff)
parentdd2cc138985dde640bddee367e896ed1ca6679e7 (diff)
Merge pull request #16390 from vespa-engine/arnej/refactor-sparse-merge-optimiser
Arnej/refactor sparse merge optimiser
Diffstat (limited to 'eval')
-rw-r--r--eval/CMakeLists.txt1
-rw-r--r--eval/src/tests/instruction/sparse_merge_function/CMakeLists.txt9
-rw-r--r--eval/src/tests/instruction/sparse_merge_function/sparse_merge_function_test.cpp82
-rw-r--r--eval/src/vespa/eval/eval/fast_value.hpp34
-rw-r--r--eval/src/vespa/eval/eval/optimize_tensor_function.cpp2
-rw-r--r--eval/src/vespa/eval/instruction/CMakeLists.txt1
-rw-r--r--eval/src/vespa/eval/instruction/generic_merge.cpp69
-rw-r--r--eval/src/vespa/eval/instruction/generic_merge.h33
-rw-r--r--eval/src/vespa/eval/instruction/sparse_merge_function.cpp146
-rw-r--r--eval/src/vespa/eval/instruction/sparse_merge_function.h22
10 files changed, 301 insertions, 98 deletions
diff --git a/eval/CMakeLists.txt b/eval/CMakeLists.txt
index 23127cc12b5..0cba519bf88 100644
--- a/eval/CMakeLists.txt
+++ b/eval/CMakeLists.txt
@@ -66,6 +66,7 @@ vespa_define_module(
src/tests/instruction/pow_as_map_optimizer
src/tests/instruction/remove_trivial_dimension_optimizer
src/tests/instruction/sparse_dot_product_function
+ src/tests/instruction/sparse_merge_function
src/tests/instruction/sum_max_dot_product_function
src/tests/instruction/vector_from_doubles_function
src/tests/streamed/value
diff --git a/eval/src/tests/instruction/sparse_merge_function/CMakeLists.txt b/eval/src/tests/instruction/sparse_merge_function/CMakeLists.txt
new file mode 100644
index 00000000000..f905bdd8c1b
--- /dev/null
+++ b/eval/src/tests/instruction/sparse_merge_function/CMakeLists.txt
@@ -0,0 +1,9 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+vespa_add_executable(eval_sparse_merge_function_test_app TEST
+ SOURCES
+ sparse_merge_function_test.cpp
+ DEPENDS
+ vespaeval
+ GTest::GTest
+)
+vespa_add_test(NAME eval_sparse_merge_function_test_app COMMAND eval_sparse_merge_function_test_app)
diff --git a/eval/src/tests/instruction/sparse_merge_function/sparse_merge_function_test.cpp b/eval/src/tests/instruction/sparse_merge_function/sparse_merge_function_test.cpp
new file mode 100644
index 00000000000..e175286e18c
--- /dev/null
+++ b/eval/src/tests/instruction/sparse_merge_function/sparse_merge_function_test.cpp
@@ -0,0 +1,82 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/eval/eval/fast_value.h>
+#include <vespa/eval/eval/simple_value.h>
+#include <vespa/eval/instruction/sparse_merge_function.h>
+#include <vespa/eval/eval/test/eval_fixture.h>
+#include <vespa/eval/eval/test/gen_spec.h>
+#include <vespa/vespalib/gtest/gtest.h>
+
+using namespace vespalib::eval;
+using namespace vespalib::eval::test;
+
+const ValueBuilderFactory &prod_factory = FastValueBuilderFactory::get();
+const ValueBuilderFactory &test_factory = SimpleValueBuilderFactory::get();
+
+//-----------------------------------------------------------------------------
+
+EvalFixture::ParamRepo make_params() {
+ return EvalFixture::ParamRepo()
+ .add("scalar1", GenSpec(1.0).gen())
+ .add("scalar2", GenSpec(2.0).gen())
+ .add_variants("v1_x", GenSpec(3.0).map("x", 32, 1))
+ .add_variants("v2_x", GenSpec(4.0).map("x", 16, 2))
+ .add_variants("v3_xz", GenSpec(5.0).map("x", 16, 2).idx("z", 1))
+ .add("dense", GenSpec(6.0).idx("x", 10).gen())
+ .add("m1_xy", GenSpec(7.0).map("x", 32, 1).map("y", 16, 2).gen())
+ .add("m2_xy", GenSpec(8.0).map("x", 16, 2).map("y", 32, 1).gen())
+ .add("mixed", GenSpec(9.0).map("x", 8, 1).idx("y", 5).gen());
+}
+EvalFixture::ParamRepo param_repo = make_params();
+
+void assert_optimized(const vespalib::string &expr) {
+ EvalFixture fast_fixture(prod_factory, expr, param_repo, true);
+ EvalFixture test_fixture(test_factory, expr, param_repo, true);
+ EvalFixture slow_fixture(prod_factory, expr, param_repo, false);
+ EXPECT_EQ(fast_fixture.result(), EvalFixture::ref(expr, param_repo));
+ EXPECT_EQ(test_fixture.result(), EvalFixture::ref(expr, param_repo));
+ EXPECT_EQ(slow_fixture.result(), EvalFixture::ref(expr, param_repo));
+ EXPECT_EQ(fast_fixture.find_all<SparseMergeFunction>().size(), 1u);
+ EXPECT_EQ(test_fixture.find_all<SparseMergeFunction>().size(), 1u);
+ EXPECT_EQ(slow_fixture.find_all<SparseMergeFunction>().size(), 0u);
+}
+
+void assert_not_optimized(const vespalib::string &expr) {
+ EvalFixture fast_fixture(prod_factory, expr, param_repo, true);
+ EXPECT_EQ(fast_fixture.result(), EvalFixture::ref(expr, param_repo));
+ EXPECT_EQ(fast_fixture.find_all<SparseMergeFunction>().size(), 0u);
+}
+
+//-----------------------------------------------------------------------------
+
+TEST(SparseMerge, expression_can_be_optimized)
+{
+ assert_optimized("merge(v1_x,v2_x,f(x,y)(x+y))");
+ assert_optimized("merge(v1_x,v2_x,f(x,y)(max(x,y)))");
+ assert_optimized("merge(v1_x,v2_x,f(x,y)(x+y+1))");
+ assert_optimized("merge(v1_x_f,v2_x_f,f(x,y)(x+y))");
+ assert_optimized("merge(v3_xz,v3_xz,f(x,y)(x+y))");
+}
+
+TEST(SparseMerge, multi_dimensional_expression_can_be_optimized)
+{
+ assert_optimized("merge(m1_xy,m2_xy,f(x,y)(x+y))");
+ assert_optimized("merge(m1_xy,m2_xy,f(x,y)(x*y))");
+}
+
+TEST(SparseMerge, similar_expressions_are_not_optimized)
+{
+ assert_not_optimized("merge(scalar1,scalar2,f(x,y)(x+y))");
+ assert_not_optimized("merge(dense,dense,f(x,y)(x+y))");
+ assert_not_optimized("merge(mixed,mixed,f(x,y)(x+y))");
+}
+
+TEST(SparseMerge, mixed_cell_types_are_not_optimized)
+{
+ assert_not_optimized("merge(v1_x,v2_x_f,f(x,y)(x+y))");
+ assert_not_optimized("merge(v1_x_f,v2_x,f(x,y)(x+y))");
+}
+
+//-----------------------------------------------------------------------------
+
+GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/eval/src/vespa/eval/eval/fast_value.hpp b/eval/src/vespa/eval/eval/fast_value.hpp
index 88319df7590..6673494ccd2 100644
--- a/eval/src/vespa/eval/eval/fast_value.hpp
+++ b/eval/src/vespa/eval/eval/fast_value.hpp
@@ -155,12 +155,6 @@ struct FastValueIndex final : Value::Index {
const std::vector<JoinAddrSource> &addr_sources,
ConstArrayRef<LCT> lhs_cells, ConstArrayRef<RCT> rhs_cells, Stash &stash);
- template <typename LCT, typename RCT, typename OCT, typename Fun>
- static const Value &sparse_only_merge(const ValueType &res_type, const Fun &fun,
- const FastValueIndex &lhs, const FastValueIndex &rhs,
- ConstArrayRef<LCT> lhs_cells, ConstArrayRef<RCT> rhs_cells,
- Stash &stash) __attribute((noinline));
-
size_t size() const override { return map.size(); }
std::unique_ptr<View> create_view(const std::vector<size_t> &dims) const override;
};
@@ -429,32 +423,4 @@ FastValueIndex::sparse_no_overlap_join(const ValueType &res_type, const Fun &fun
//-----------------------------------------------------------------------------
-template <typename LCT, typename RCT, typename OCT, typename Fun>
-const Value &
-FastValueIndex::sparse_only_merge(const ValueType &res_type, const Fun &fun,
- const FastValueIndex &lhs, const FastValueIndex &rhs,
- ConstArrayRef<LCT> lhs_cells, ConstArrayRef<RCT> rhs_cells, Stash &stash)
-{
- size_t guess_size = lhs.map.size() + rhs.map.size();
- auto &result = stash.create<FastValue<OCT,true>>(res_type, lhs.map.addr_size(), 1, guess_size);
- lhs.map.each_map_entry([&](auto lhs_subspace, auto hash)
- {
- result.add_mapping(lhs.map.get_addr(lhs_subspace), hash);
- result.my_cells.push_back_fast(lhs_cells[lhs_subspace]);
- });
- rhs.map.each_map_entry([&](auto rhs_subspace, auto hash)
- {
- auto rhs_addr = rhs.map.get_addr(rhs_subspace);
- auto result_subspace = result.my_index.map.lookup(rhs_addr, hash);
- if (result_subspace == FastAddrMap::npos()) {
- result.add_mapping(rhs_addr, hash);
- result.my_cells.push_back_fast(rhs_cells[rhs_subspace]);
- } else {
- OCT &out_cell = *result.my_cells.get(result_subspace);
- out_cell = fun(out_cell, rhs_cells[rhs_subspace]);
- }
- });
- return result;
-}
-
}
diff --git a/eval/src/vespa/eval/eval/optimize_tensor_function.cpp b/eval/src/vespa/eval/eval/optimize_tensor_function.cpp
index 25612b8d5fd..196e8a98679 100644
--- a/eval/src/vespa/eval/eval/optimize_tensor_function.cpp
+++ b/eval/src/vespa/eval/eval/optimize_tensor_function.cpp
@@ -6,6 +6,7 @@
#include <vespa/eval/instruction/dense_dot_product_function.h>
#include <vespa/eval/instruction/sparse_dot_product_function.h>
+#include <vespa/eval/instruction/sparse_merge_function.h>
#include <vespa/eval/instruction/mixed_inner_product_function.h>
#include <vespa/eval/instruction/sum_max_dot_product_function.h>
#include <vespa/eval/instruction/dense_xw_product_function.h>
@@ -72,6 +73,7 @@ const TensorFunction &optimize_for_factory(const ValueBuilderFactory &, const Te
child.set(MixedSimpleJoinFunction::optimize(child.get(), stash));
child.set(JoinWithNumberFunction::optimize(child.get(), stash));
child.set(DenseSingleReduceFunction::optimize(child.get(), stash));
+ child.set(SparseMergeFunction::optimize(child.get(), stash));
nodes.pop_back();
}
}
diff --git a/eval/src/vespa/eval/instruction/CMakeLists.txt b/eval/src/vespa/eval/instruction/CMakeLists.txt
index cac69d23640..3def8907ac8 100644
--- a/eval/src/vespa/eval/instruction/CMakeLists.txt
+++ b/eval/src/vespa/eval/instruction/CMakeLists.txt
@@ -33,6 +33,7 @@ vespa_add_library(eval_instruction OBJECT
remove_trivial_dimension_optimizer.cpp
replace_type_function.cpp
sparse_dot_product_function.cpp
+ sparse_merge_function.cpp
sum_max_dot_product_function.cpp
vector_from_doubles_function.cpp
)
diff --git a/eval/src/vespa/eval/instruction/generic_merge.cpp b/eval/src/vespa/eval/instruction/generic_merge.cpp
index b40388aa547..9b098db7763 100644
--- a/eval/src/vespa/eval/instruction/generic_merge.cpp
+++ b/eval/src/vespa/eval/instruction/generic_merge.cpp
@@ -17,37 +17,6 @@ namespace vespalib::eval::instruction {
using State = InterpretedFunction::State;
using Instruction = InterpretedFunction::Instruction;
-namespace {
-
-//-----------------------------------------------------------------------------
-
-struct MergeParam {
- const ValueType res_type;
- const join_fun_t function;
- const size_t num_mapped_dimensions;
- const size_t dense_subspace_size;
- std::vector<size_t> all_view_dims;
- const ValueBuilderFactory &factory;
- MergeParam(const ValueType &lhs_type, const ValueType &rhs_type,
- join_fun_t function_in, const ValueBuilderFactory &factory_in)
- : res_type(ValueType::join(lhs_type, rhs_type)),
- function(function_in),
- num_mapped_dimensions(lhs_type.count_mapped_dimensions()),
- dense_subspace_size(lhs_type.dense_subspace_size()),
- all_view_dims(num_mapped_dimensions),
- factory(factory_in)
- {
- assert(!res_type.is_error());
- assert(num_mapped_dimensions == rhs_type.count_mapped_dimensions());
- assert(num_mapped_dimensions == res_type.count_mapped_dimensions());
- assert(dense_subspace_size == rhs_type.dense_subspace_size());
- assert(dense_subspace_size == res_type.dense_subspace_size());
- for (size_t i = 0; i < num_mapped_dimensions; ++i) {
- all_view_dims[i] = i;
- }
- }
- ~MergeParam();
-};
MergeParam::~MergeParam() = default;
//-----------------------------------------------------------------------------
@@ -108,39 +77,14 @@ generic_mixed_merge(const Value &a, const Value &b,
return builder->build(std::move(builder));
}
-template <typename LCT, typename RCT, typename OCT, typename Fun>
-void my_mixed_merge_op(State &state, uint64_t param_in) {
- const auto &param = unwrap_param<MergeParam>(param_in);
- const Value &lhs = state.peek(1);
- const Value &rhs = state.peek(0);
- auto up = generic_mixed_merge<LCT, RCT, OCT, Fun>(lhs, rhs, param);
- auto &result = state.stash.create<std::unique_ptr<Value>>(std::move(up));
- const Value &result_ref = *(result.get());
- state.pop_pop_push(result_ref);
-};
+
+namespace {
template <typename LCT, typename RCT, typename OCT, typename Fun>
-void my_sparse_merge_op(State &state, uint64_t param_in) {
+void my_mixed_merge_op(State &state, uint64_t param_in) {
const auto &param = unwrap_param<MergeParam>(param_in);
const Value &lhs = state.peek(1);
const Value &rhs = state.peek(0);
- if (auto indexes = detect_type<FastValueIndex>(lhs.index(), rhs.index())) {
- auto lhs_cells = lhs.cells().typify<LCT>();
- auto rhs_cells = rhs.cells().typify<RCT>();
- if (lhs_cells.size() < rhs_cells.size()) {
- return state.pop_pop_push(
- FastValueIndex::sparse_only_merge<RCT,LCT,OCT,Fun>(
- param.res_type, Fun(param.function),
- indexes.get<1>(), indexes.get<0>(),
- rhs_cells, lhs_cells, state.stash));
- } else {
- return state.pop_pop_push(
- FastValueIndex::sparse_only_merge<LCT,RCT,OCT,Fun>(
- param.res_type, Fun(param.function),
- indexes.get<0>(), indexes.get<1>(),
- lhs_cells, rhs_cells, state.stash));
- }
- }
auto up = generic_mixed_merge<LCT, RCT, OCT, Fun>(lhs, rhs, param);
auto &result = state.stash.create<std::unique_ptr<Value>>(std::move(up));
const Value &result_ref = *(result.get());
@@ -148,10 +92,7 @@ void my_sparse_merge_op(State &state, uint64_t param_in) {
};
struct SelectGenericMergeOp {
- template <typename LCT, typename RCT, typename OCT, typename Fun> static auto invoke(const MergeParam &param) {
- if (param.dense_subspace_size == 1) {
- return my_sparse_merge_op<LCT,RCT,OCT,Fun>;
- }
+ template <typename LCT, typename RCT, typename OCT, typename Fun> static auto invoke() {
return my_mixed_merge_op<LCT,RCT,OCT,Fun>;
}
};
@@ -167,7 +108,7 @@ GenericMerge::make_instruction(const ValueType &lhs_type, const ValueType &rhs_t
const ValueBuilderFactory &factory, Stash &stash)
{
const auto &param = stash.create<MergeParam>(lhs_type, rhs_type, function, factory);
- auto fun = typify_invoke<4,MergeTypify,SelectGenericMergeOp>(lhs_type.cell_type(), rhs_type.cell_type(), param.res_type.cell_type(), function, param);
+ auto fun = typify_invoke<4,MergeTypify,SelectGenericMergeOp>(lhs_type.cell_type(), rhs_type.cell_type(), param.res_type.cell_type(), function);
return Instruction(fun, wrap_param<MergeParam>(param));
}
diff --git a/eval/src/vespa/eval/instruction/generic_merge.h b/eval/src/vespa/eval/instruction/generic_merge.h
index 2b2964366cc..0319f1a929f 100644
--- a/eval/src/vespa/eval/instruction/generic_merge.h
+++ b/eval/src/vespa/eval/instruction/generic_merge.h
@@ -6,6 +6,39 @@
namespace vespalib::eval::instruction {
+struct MergeParam {
+ const ValueType res_type;
+ const join_fun_t function;
+ const size_t num_mapped_dimensions;
+ const size_t dense_subspace_size;
+ std::vector<size_t> all_view_dims;
+ const ValueBuilderFactory &factory;
+ MergeParam(const ValueType &lhs_type, const ValueType &rhs_type,
+ join_fun_t function_in, const ValueBuilderFactory &factory_in)
+ : res_type(ValueType::join(lhs_type, rhs_type)),
+ function(function_in),
+ num_mapped_dimensions(lhs_type.count_mapped_dimensions()),
+ dense_subspace_size(lhs_type.dense_subspace_size()),
+ all_view_dims(num_mapped_dimensions),
+ factory(factory_in)
+ {
+ assert(!res_type.is_error());
+ assert(num_mapped_dimensions == rhs_type.count_mapped_dimensions());
+ assert(num_mapped_dimensions == res_type.count_mapped_dimensions());
+ assert(dense_subspace_size == rhs_type.dense_subspace_size());
+ assert(dense_subspace_size == res_type.dense_subspace_size());
+ for (size_t i = 0; i < num_mapped_dimensions; ++i) {
+ all_view_dims[i] = i;
+ }
+ }
+ ~MergeParam();
+};
+
+template <typename LCT, typename RCT, typename OCT, typename Fun>
+std::unique_ptr<Value>
+generic_mixed_merge(const Value &a, const Value &b,
+ const MergeParam &params);
+
struct GenericMerge {
static InterpretedFunction::Instruction
make_instruction(const ValueType &lhs_type, const ValueType &rhs_type,
diff --git a/eval/src/vespa/eval/instruction/sparse_merge_function.cpp b/eval/src/vespa/eval/instruction/sparse_merge_function.cpp
new file mode 100644
index 00000000000..924c4d69fe9
--- /dev/null
+++ b/eval/src/vespa/eval/instruction/sparse_merge_function.cpp
@@ -0,0 +1,146 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "sparse_merge_function.h"
+#include "generic_merge.h"
+#include <vespa/eval/eval/fast_value.hpp>
+#include <vespa/vespalib/util/typify.h>
+
+namespace vespalib::eval {
+
+using namespace tensor_function;
+using namespace operation;
+using namespace instruction;
+
+namespace {
+
+template <typename CT, bool single_dim, typename Fun>
+const Value& my_fast_sparse_merge(const FastAddrMap &a_map, const FastAddrMap &b_map,
+ const CT *a_cells, const CT *b_cells,
+ const MergeParam &params,
+ Stash &stash)
+{
+ Fun fun(params.function);
+ size_t guess_size = a_map.size() + b_map.size();
+ auto &result = stash.create<FastValue<CT,true>>(params.res_type, params.num_mapped_dimensions, 1u, guess_size);
+ if constexpr (single_dim) {
+ string_id cur_label;
+ ConstArrayRef<string_id> addr(&cur_label, 1);
+ const auto &a_labels = a_map.labels();
+ for (size_t i = 0; i < a_labels.size(); ++i) {
+ cur_label = a_labels[i];
+ result.add_mapping(addr, cur_label.hash());
+ result.my_cells.push_back_fast(a_cells[i]);
+ }
+ const auto &b_labels = b_map.labels();
+ for (size_t i = 0; i < b_labels.size(); ++i) {
+ cur_label = b_labels[i];
+ auto result_subspace = result.my_index.map.lookup_singledim(cur_label);
+ if (result_subspace == FastAddrMap::npos()) {
+ result.add_mapping(addr, cur_label.hash());
+ result.my_cells.push_back_fast(b_cells[i]);
+ } else {
+ CT *out_cell = result.my_cells.get(result_subspace);
+ out_cell[0] = fun(out_cell[0], b_cells[i]);
+ }
+ }
+ } else {
+ a_map.each_map_entry([&](auto lhs_subspace, auto hash)
+ {
+ result.add_mapping(a_map.get_addr(lhs_subspace), hash);
+ result.my_cells.push_back_fast(a_cells[lhs_subspace]);
+ });
+ b_map.each_map_entry([&](auto rhs_subspace, auto hash)
+ {
+ auto rhs_addr = b_map.get_addr(rhs_subspace);
+ auto result_subspace = result.my_index.map.lookup(rhs_addr, hash);
+ if (result_subspace == FastAddrMap::npos()) {
+ result.add_mapping(rhs_addr, hash);
+ result.my_cells.push_back_fast(b_cells[rhs_subspace]);
+ } else {
+ CT *out_cell = result.my_cells.get(result_subspace);
+ out_cell[0] = fun(out_cell[0], b_cells[rhs_subspace]);
+ }
+ });
+ }
+ return result;
+}
+
+template <typename CT, bool single_dim, typename Fun>
+void my_sparse_merge_op(InterpretedFunction::State &state, uint64_t param_in) {
+ const auto &param = unwrap_param<MergeParam>(param_in);
+ assert(param.dense_subspace_size == 1u);
+ const Value &a = state.peek(1);
+ const Value &b = state.peek(0);
+ const auto &a_idx = a.index();
+ const auto &b_idx = b.index();
+ if (__builtin_expect(are_fast(a_idx, b_idx), true)) {
+ auto a_cells = a.cells().typify<CT>();
+ auto b_cells = b.cells().typify<CT>();
+ const Value &v = my_fast_sparse_merge<CT,single_dim,Fun>(as_fast(a_idx).map, as_fast(b_idx).map,
+ a_cells.cbegin(), b_cells.cbegin(),
+ param, state.stash);
+ state.pop_pop_push(v);
+ } else {
+ auto up = generic_mixed_merge<CT,CT,CT,Fun>(a, b, param);
+ state.pop_pop_push(*state.stash.create<std::unique_ptr<Value>>(std::move(up)));
+ }
+}
+
+struct SelectSparseMergeOp {
+ template <typename CT, typename SINGLE_DIM, typename Fun>
+ static auto invoke() { return my_sparse_merge_op<CT,SINGLE_DIM::value,Fun>; }
+};
+
+using MyTypify = TypifyValue<TypifyCellType,TypifyBool,operation::TypifyOp2>;
+
+} // namespace <unnamed>
+
+SparseMergeFunction::SparseMergeFunction(const tensor_function::Merge &original)
+ : tensor_function::Merge(original.result_type(),
+ original.lhs(),
+ original.rhs(),
+ original.function())
+{
+ assert(compatible_types(result_type(), lhs().result_type(), rhs().result_type()));
+}
+
+InterpretedFunction::Instruction
+SparseMergeFunction::compile_self(const ValueBuilderFactory &factory, Stash &stash) const
+{
+ const auto &param = stash.create<MergeParam>(lhs().result_type(), rhs().result_type(),
+ function(), factory);
+ size_t num_dims = result_type().count_mapped_dimensions();
+ auto op = typify_invoke<3,MyTypify,SelectSparseMergeOp>(result_type().cell_type(),
+ num_dims == 1,
+ function());
+ return InterpretedFunction::Instruction(op, wrap_param<MergeParam>(param));
+}
+
+bool
+SparseMergeFunction::compatible_types(const ValueType &res, const ValueType &lhs, const ValueType &rhs)
+{
+ if ((lhs.cell_type() == rhs.cell_type())
+ && (lhs.count_mapped_dimensions() > 0)
+ && (lhs.dense_subspace_size() == 1))
+ {
+ assert(res == lhs);
+ assert(res == rhs);
+ return true;
+ }
+ return false;
+}
+
+const TensorFunction &
+SparseMergeFunction::optimize(const TensorFunction &expr, Stash &stash)
+{
+ if (auto merge = as<Merge>(expr)) {
+ const TensorFunction &lhs = merge->lhs();
+ const TensorFunction &rhs = merge->rhs();
+ if (compatible_types(expr.result_type(), lhs.result_type(), rhs.result_type())) {
+ return stash.create<SparseMergeFunction>(*merge);
+ }
+ }
+ return expr;
+}
+
+} // namespace
diff --git a/eval/src/vespa/eval/instruction/sparse_merge_function.h b/eval/src/vespa/eval/instruction/sparse_merge_function.h
new file mode 100644
index 00000000000..d2b26196ed6
--- /dev/null
+++ b/eval/src/vespa/eval/instruction/sparse_merge_function.h
@@ -0,0 +1,22 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include <vespa/eval/eval/tensor_function.h>
+
+namespace vespalib::eval {
+
+/**
+ * Tensor function for merging two sparse tensors.
+ */
+class SparseMergeFunction : public tensor_function::Merge
+{
+public:
+ SparseMergeFunction(const tensor_function::Merge &original);
+ InterpretedFunction::Instruction compile_self(const ValueBuilderFactory &factory, Stash &stash) const override;
+ bool result_is_mutable() const override { return true; }
+ static bool compatible_types(const ValueType &res, const ValueType &lhs, const ValueType &rhs);
+ static const TensorFunction &optimize(const TensorFunction &expr, Stash &stash);
+};
+
+} // namespace