summary refs log tree commit diff stats
path: root/eval
diff options
context:
space:
mode:
author Håvard Pettersen <havardpe@oath.com> 2021-02-16 14:30:31 +0000
committer Håvard Pettersen <havardpe@oath.com> 2021-02-17 12:58:34 +0000
commit 2e1a333ce13a465ced7137ad12d070575031fd65 (patch)
tree 92f028d91965874d4d574fdf69b6733b90f383fd /eval
parent 37fd0e320d0f78c0442fa73fdaddaec33c916d28 (diff)
sparse full overlap join
Diffstat (limited to 'eval')
-rw-r--r-- eval/CMakeLists.txt 1
-rw-r--r-- eval/src/tests/instruction/sparse_full_overlap_join_function/CMakeLists.txt 9
-rw-r--r-- eval/src/tests/instruction/sparse_full_overlap_join_function/sparse_full_overlap_join_function_test.cpp 92
-rw-r--r-- eval/src/vespa/eval/eval/fast_value.hpp 33
-rw-r--r-- eval/src/vespa/eval/eval/optimize_tensor_function.cpp 2
-rw-r--r-- eval/src/vespa/eval/instruction/CMakeLists.txt 1
-rw-r--r-- eval/src/vespa/eval/instruction/detect_type.h 76
-rw-r--r-- eval/src/vespa/eval/instruction/generic_join.cpp 44
-rw-r--r-- eval/src/vespa/eval/instruction/generic_merge.cpp 2
-rw-r--r-- eval/src/vespa/eval/instruction/sparse_full_overlap_join_function.cpp 134
-rw-r--r-- eval/src/vespa/eval/instruction/sparse_full_overlap_join_function.h 22
11 files changed, 266 insertions, 150 deletions
diff --git a/eval/CMakeLists.txt b/eval/CMakeLists.txt
index 22750a27f20..66270bb323b 100644
--- a/eval/CMakeLists.txt
+++ b/eval/CMakeLists.txt
@@ -66,6 +66,7 @@ vespa_define_module(
src/tests/instruction/pow_as_map_optimizer
src/tests/instruction/remove_trivial_dimension_optimizer
src/tests/instruction/sparse_dot_product_function
+ src/tests/instruction/sparse_full_overlap_join_function
src/tests/instruction/sparse_merge_function
src/tests/instruction/sparse_no_overlap_join_function
src/tests/instruction/sum_max_dot_product_function
diff --git a/eval/src/tests/instruction/sparse_full_overlap_join_function/CMakeLists.txt b/eval/src/tests/instruction/sparse_full_overlap_join_function/CMakeLists.txt
new file mode 100644
index 00000000000..54841140278
--- /dev/null
+++ b/eval/src/tests/instruction/sparse_full_overlap_join_function/CMakeLists.txt
@@ -0,0 +1,9 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+vespa_add_executable(eval_sparse_full_overlap_join_function_test_app TEST
+ SOURCES
+ sparse_full_overlap_join_function_test.cpp
+ DEPENDS
+ vespaeval
+ GTest::GTest
+)
+vespa_add_test(NAME eval_sparse_full_overlap_join_function_test_app COMMAND eval_sparse_full_overlap_join_function_test_app)
diff --git a/eval/src/tests/instruction/sparse_full_overlap_join_function/sparse_full_overlap_join_function_test.cpp b/eval/src/tests/instruction/sparse_full_overlap_join_function/sparse_full_overlap_join_function_test.cpp
new file mode 100644
index 00000000000..e3001b17602
--- /dev/null
+++ b/eval/src/tests/instruction/sparse_full_overlap_join_function/sparse_full_overlap_join_function_test.cpp
@@ -0,0 +1,92 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/eval/eval/fast_value.h>
+#include <vespa/eval/eval/simple_value.h>
+#include <vespa/eval/instruction/sparse_full_overlap_join_function.h>
+#include <vespa/eval/eval/test/eval_fixture.h>
+#include <vespa/vespalib/gtest/gtest.h>
+
+using namespace vespalib::eval;
+using namespace vespalib::eval::test;
+
+const ValueBuilderFactory &prod_factory = FastValueBuilderFactory::get();
+const ValueBuilderFactory &test_factory = SimpleValueBuilderFactory::get();
+
+//-----------------------------------------------------------------------------
+
+EvalFixture::ParamRepo make_params() { // builds the named tensor parameters that the test expressions below refer to
+ return EvalFixture::ParamRepo()
+ .add_variants("v1_a", GenSpec(3.0).map("a", 8, 1)) // NOTE(review): add_variants presumably also registers the "_f" variants used by mixed_cell_types_are_not_optimized -- confirm in EvalFixture
+ .add_variants("v2_a", GenSpec(7.0).map("a", 4, 2)) // same single mapped dimension "a" as v1_a, generated with different map() arguments
+ .add_variants("v2_a_trivial", GenSpec(7.0).map("a", 4, 2).idx("b", 1).idx("c", 1)) // v2_a plus two idx(..., 1) dimensions (presumably size 1, hence "trivial")
+ .add_variants("v3_b", GenSpec(5.0).map("b", 4, 2)) // mapped dimension "b" only; shares no dimension name with v1_a/v2_a
+ .add("m1_ab", GenSpec(3.0).map("a", 8, 1).map("b", 8, 1)) // two mapped dimensions: a, b
+ .add("m2_ab", GenSpec(17.0).map("a", 4, 2).map("b", 4, 2)) // same dimensions as m1_ab
+ .add("m3_bc", GenSpec(11.0).map("b", 4, 2).map("c", 4, 2)) // shares only "b" with m1_ab/m2_ab
+ .add("scalar", GenSpec(1.0)) // no dimensions declared
+ .add("dense_a", GenSpec().idx("a", 5)) // indexed dimension only
+ .add("mixed_ab", GenSpec().map("a", 5, 1).idx("b", 5)); // one mapped and one indexed dimension
+}
+EvalFixture::ParamRepo param_repo = make_params();
+
+void assert_optimized(const vespalib::string &expr) { // verifies that expr evaluates correctly and is rewritten to exactly one SparseFullOverlapJoinFunction
+ EvalFixture fast_fixture(prod_factory, expr, param_repo, true); // FastValue factory; the last argument presumably switches optimization on -- confirm against EvalFixture
+ EvalFixture test_fixture(test_factory, expr, param_repo, true); // SimpleValue factory, same flag; presumably exercises the non-fast fallback of the op -- confirm
+ EvalFixture slow_fixture(prod_factory, expr, param_repo, false); // FastValue factory with the flag off
+ EXPECT_EQ(fast_fixture.result(), EvalFixture::ref(expr, param_repo)); // every fixture must reproduce the reference result
+ EXPECT_EQ(test_fixture.result(), EvalFixture::ref(expr, param_repo));
+ EXPECT_EQ(slow_fixture.result(), EvalFixture::ref(expr, param_repo));
+ EXPECT_EQ(fast_fixture.find_all<SparseFullOverlapJoinFunction>().size(), 1u); // with the flag on: exactly one optimized node, regardless of factory
+ EXPECT_EQ(test_fixture.find_all<SparseFullOverlapJoinFunction>().size(), 1u);
+ EXPECT_EQ(slow_fixture.find_all<SparseFullOverlapJoinFunction>().size(), 0u); // with the flag off: the node must not appear
+}
+
+void assert_not_optimized(const vespalib::string &expr) { // verifies that expr still evaluates correctly but is NOT rewritten to SparseFullOverlapJoinFunction
+ EvalFixture fast_fixture(prod_factory, expr, param_repo, true); // same setup as the "fast" fixture in assert_optimized
+ EXPECT_EQ(fast_fixture.result(), EvalFixture::ref(expr, param_repo)); // result must still match the reference
+ EXPECT_EQ(fast_fixture.find_all<SparseFullOverlapJoinFunction>().size(), 0u); // no optimized node may be present
+}
+
+//-----------------------------------------------------------------------------
+
+TEST(SparseFullOverlapJoin, expression_can_be_optimized) // both operands have the single mapped dimension "a"
+{
+ assert_optimized("v1_a-v2_a"); // '-' is not commutative, so checking both operand orders guards against swapped arguments
+ assert_optimized("v2_a-v1_a");
+ assert_optimized("join(v1_a,v2_a,f(x,y)(max(x,y)))"); // explicit join with a lambda instead of an operator
+}
+
+TEST(SparseFullOverlapJoin, multi_dimensional_expression_can_be_optimized) // both operands have the mapped dimensions "a" and "b"
+{
+ assert_optimized("m1_ab-m2_ab"); // non-commutative operator in both operand orders
+ assert_optimized("m2_ab-m1_ab");
+ assert_optimized("join(m1_ab,m2_ab,f(x,y)(max(x,y)))"); // explicit join with a lambda
+}
+
+TEST(SparseFullOverlapJoin, trivial_dimensions_are_ignored) // v2_a_trivial adds idx("b", 1) and idx("c", 1) to v2_a; the optimization must still apply
+{
+ assert_optimized("v1_a*v2_a_trivial"); // trivial dimensions on the rhs
+ assert_optimized("v2_a_trivial*v1_a"); // trivial dimensions on the lhs
+}
+
+TEST(SparseFullOverlapJoin, inappropriate_shapes_are_not_optimized) // each expression is expected to fail the type check in compatible_types()
+{
+ assert_not_optimized("v1_a*scalar"); // scalar has no mapped dimension
+ assert_not_optimized("v1_a*mixed_ab"); // mixed_ab carries idx("b", 5) in addition to its mapped dimension
+ assert_not_optimized("v1_a*v3_b"); // no shared dimension ("a" vs "b")
+ assert_not_optimized("v1_a*m1_ab"); // rhs has an extra mapped dimension "b"
+ assert_not_optimized("m1_ab*m3_bc"); // only "b" is shared
+ assert_not_optimized("scalar*scalar"); // no mapped dimensions at all
+ assert_not_optimized("dense_a*dense_a"); // indexed dimension only
+ assert_not_optimized("mixed_ab*mixed_ab"); // identical dimensions, but idx("b", 5) is present on both sides
+}
+
+TEST(SparseFullOverlapJoin, mixed_cell_types_are_not_optimized) // compatible_types() requires lhs and rhs to have the same cell type
+{
+ assert_not_optimized("v1_a*v2_a_f"); // NOTE(review): "_f" names are presumably the float-cell variants created by add_variants -- confirm
+ assert_not_optimized("v1_a_f*v2_a"); // same check with the variant on the lhs
+}
+
+//-----------------------------------------------------------------------------
+
+GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/eval/src/vespa/eval/eval/fast_value.hpp b/eval/src/vespa/eval/eval/fast_value.hpp
index d5cfc9c6368..33624fb920e 100644
--- a/eval/src/vespa/eval/eval/fast_value.hpp
+++ b/eval/src/vespa/eval/eval/fast_value.hpp
@@ -47,7 +47,7 @@ struct FastFilterView : public Value::Index::View {
const FastAddrMap &map;
std::vector<size_t> match_dims;
std::vector<size_t> extract_dims;
- std::vector<string_id> query;
+ std::vector<string_id> query;
size_t pos;
bool is_match(ConstArrayRef<string_id> addr) const {
@@ -141,12 +141,6 @@ struct FastValueIndex final : Value::Index {
FastAddrMap map;
FastValueIndex(size_t num_mapped_dims_in, const std::vector<string_id> &labels, size_t expected_subspaces_in)
: map(num_mapped_dims_in, labels, expected_subspaces_in) {}
-
- template <typename LCT, typename RCT, typename OCT, typename Fun>
- static const Value &sparse_full_overlap_join(const ValueType &res_type, const Fun &fun,
- const FastValueIndex &lhs, const FastValueIndex &rhs,
- ConstArrayRef<LCT> lhs_cells, ConstArrayRef<RCT> rhs_cells, Stash &stash);
-
size_t size() const override { return map.size(); }
std::unique_ptr<View> create_view(const std::vector<size_t> &dims) const override;
};
@@ -267,6 +261,10 @@ struct FastValue final : Value, ValueBuilder<T> {
}
my_index.map.add_mapping(hash);
}
+ void add_singledim_mapping(string_id label) { // registers one subspace whose address is the single given label
+ my_handles.push_back(label); // record the label handle alongside the value
+ my_index.map.add_mapping(FastAddrMap::hash_label(label)); // address hash is derived from the one label only
+ }
ArrayRef<T> add_subspace(ConstArrayRef<vespalib::stringref> addr) override {
add_mapping(addr);
return my_cells.add_cells(my_subspace_size);
@@ -344,25 +342,4 @@ struct FastScalarBuilder final : ValueBuilder<T> {
//-----------------------------------------------------------------------------
-template <typename LCT, typename RCT, typename OCT, typename Fun>
-const Value &
-FastValueIndex::sparse_full_overlap_join(const ValueType &res_type, const Fun &fun,
- const FastValueIndex &lhs, const FastValueIndex &rhs,
- ConstArrayRef<LCT> lhs_cells, ConstArrayRef<RCT> rhs_cells, Stash &stash)
-{
- auto &result = stash.create<FastValue<OCT,true>>(res_type, lhs.map.addr_size(), 1, lhs.map.size());
- lhs.map.each_map_entry([&](auto lhs_subspace, auto hash) {
- auto lhs_addr = lhs.map.get_addr(lhs_subspace);
- auto rhs_subspace = rhs.map.lookup(lhs_addr, hash);
- if (rhs_subspace != FastAddrMap::npos()) {
- result.add_mapping(lhs_addr, hash);
- auto cell_value = fun(lhs_cells[lhs_subspace], rhs_cells[rhs_subspace]);
- result.my_cells.push_back_fast(cell_value);
- }
- });
- return result;
-}
-
-//-----------------------------------------------------------------------------
-
}
diff --git a/eval/src/vespa/eval/eval/optimize_tensor_function.cpp b/eval/src/vespa/eval/eval/optimize_tensor_function.cpp
index aef49a2c75b..f1ce293b18c 100644
--- a/eval/src/vespa/eval/eval/optimize_tensor_function.cpp
+++ b/eval/src/vespa/eval/eval/optimize_tensor_function.cpp
@@ -8,6 +8,7 @@
#include <vespa/eval/instruction/sparse_dot_product_function.h>
#include <vespa/eval/instruction/sparse_merge_function.h>
#include <vespa/eval/instruction/sparse_no_overlap_join_function.h>
+#include <vespa/eval/instruction/sparse_full_overlap_join_function.h>
#include <vespa/eval/instruction/mixed_inner_product_function.h>
#include <vespa/eval/instruction/sum_max_dot_product_function.h>
#include <vespa/eval/instruction/dense_xw_product_function.h>
@@ -76,6 +77,7 @@ const TensorFunction &optimize_for_factory(const ValueBuilderFactory &, const Te
child.set(DenseSingleReduceFunction::optimize(child.get(), stash));
child.set(SparseMergeFunction::optimize(child.get(), stash));
child.set(SparseNoOverlapJoinFunction::optimize(child.get(), stash));
+ child.set(SparseFullOverlapJoinFunction::optimize(child.get(), stash));
nodes.pop_back();
}
}
diff --git a/eval/src/vespa/eval/instruction/CMakeLists.txt b/eval/src/vespa/eval/instruction/CMakeLists.txt
index 50f7dbe7005..97838f2adb9 100644
--- a/eval/src/vespa/eval/instruction/CMakeLists.txt
+++ b/eval/src/vespa/eval/instruction/CMakeLists.txt
@@ -33,6 +33,7 @@ vespa_add_library(eval_instruction OBJECT
remove_trivial_dimension_optimizer.cpp
replace_type_function.cpp
sparse_dot_product_function.cpp
+ sparse_full_overlap_join_function.cpp
sparse_merge_function.cpp
sparse_no_overlap_join_function.cpp
sum_max_dot_product_function.cpp
diff --git a/eval/src/vespa/eval/instruction/detect_type.h b/eval/src/vespa/eval/instruction/detect_type.h
deleted file mode 100644
index f1769fa15cc..00000000000
--- a/eval/src/vespa/eval/instruction/detect_type.h
+++ /dev/null
@@ -1,76 +0,0 @@
-// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-#include <typeindex>
-#include <array>
-#include <cstddef>
-
-#pragma once
-
-namespace vespalib::eval::instruction {
-
-/*
- * Utilities for detecting implementation class by comparing
- * typeindex(typeid(T)); for now these are local to this
- * namespace, but we can consider moving them to a more
- * common place (probably vespalib) if we see more use-cases.
- */
-
-/**
- * Recognize a (const) instance of type T. This is cheaper than
- * dynamic_cast, but requires the object to be exactly of class T.
- * Returns a pointer to the object as T if recognized, nullptr
- * otherwise.
- **/
-template<typename T, typename U>
-const T *
-recognize_by_type_index(const U & object)
-{
- if (std::type_index(typeid(object)) == std::type_index(typeid(T))) {
- return static_cast<const T *>(&object);
- }
- return nullptr;
-}
-
-/**
- * Packs N recognized values into one object, used as return value
- * from detect_type<T>.
- *
- * Use all_converted() or the equivalent bool cast operator to check
- * if all objects were recognized. After this check is successful use
- * get<0>(), get<1>() etc to get a reference to the objects.
- **/
-template<typename T, size_t N>
-class RecognizedValues
-{
-private:
- std::array<const T *, N> _pointers;
-public:
- RecognizedValues(std::array<const T *, N> && pointers)
- : _pointers(std::move(pointers))
- {}
- bool all_converted() const {
- for (auto p : _pointers) {
- if (p == nullptr) return false;
- }
- return true;
- }
- operator bool() const { return all_converted(); }
- template<size_t idx> const T& get() const {
- static_assert(idx < N);
- return *_pointers[idx];
- }
-};
-
-/**
- * For all arguments, detect if they have typeid(T), convert to T if
- * possible, and return a RecognizedValues packing the converted
- * values.
- **/
-template<typename T, typename... Args>
-RecognizedValues<T, sizeof...(Args)>
-detect_type(const Args &... args)
-{
- return RecognizedValues<T, sizeof...(Args)>({(recognize_by_type_index<T>(args))...});
-}
-
-} // namespace
diff --git a/eval/src/vespa/eval/instruction/generic_join.cpp b/eval/src/vespa/eval/instruction/generic_join.cpp
index 4b3755509c7..fb714bcf16e 100644
--- a/eval/src/vespa/eval/instruction/generic_join.cpp
+++ b/eval/src/vespa/eval/instruction/generic_join.cpp
@@ -1,9 +1,7 @@
// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "generic_join.h"
-#include "detect_type.h"
#include <vespa/eval/eval/inline_operation.h>
-#include <vespa/eval/eval/fast_value.hpp>
#include <vespa/eval/eval/wrap_param.h>
#include <vespa/vespalib/util/overload.h>
#include <vespa/vespalib/util/stash.h>
@@ -69,43 +67,6 @@ void my_mixed_join_op(State &state, uint64_t param_in) {
//-----------------------------------------------------------------------------
-template <typename LCT, typename RCT, typename OCT, typename Fun>
-void my_sparse_full_overlap_join_op(State &state, uint64_t param_in) {
- const auto &param = unwrap_param<JoinParam>(param_in);
- const Value &lhs = state.peek(1);
- const Value &rhs = state.peek(0);
- auto lhs_cells = lhs.cells().typify<LCT>();
- auto rhs_cells = rhs.cells().typify<RCT>();
- const Value::Index &lhs_index = lhs.index();
- const Value::Index &rhs_index = rhs.index();
- if (auto indexes = detect_type<FastValueIndex>(lhs_index, rhs_index)) {
- const auto &lhs_fast = indexes.get<0>();
- const auto &rhs_fast = indexes.get<1>();
- return (rhs_fast.map.size() < lhs_fast.map.size())
- ? state.pop_pop_push(FastValueIndex::sparse_full_overlap_join<RCT,LCT,OCT,SwapArgs2<Fun>>
- (param.res_type, SwapArgs2<Fun>(param.function), rhs_fast, lhs_fast, rhs_cells, lhs_cells, state.stash))
- : state.pop_pop_push(FastValueIndex::sparse_full_overlap_join<LCT,RCT,OCT,Fun>
- (param.res_type, Fun(param.function), lhs_fast, rhs_fast, lhs_cells, rhs_cells, state.stash));
- }
- Fun fun(param.function);
- SparseJoinState sparse(param.sparse_plan, lhs_index, rhs_index);
- auto builder = param.factory.create_transient_value_builder<OCT>(param.res_type, param.sparse_plan.sources.size(), param.dense_plan.out_size, sparse.first_index.size());
- auto outer = sparse.first_index.create_view({});
- auto inner = sparse.second_index.create_view(sparse.second_view_dims);
- outer->lookup({});
- while (outer->next_result(sparse.first_address, sparse.first_subspace)) {
- inner->lookup(sparse.address_overlap);
- if (inner->next_result(sparse.second_only_address, sparse.second_subspace)) {
- builder->add_subspace(sparse.full_address)[0] = fun(lhs_cells[sparse.lhs_subspace], rhs_cells[sparse.rhs_subspace]);
- }
- }
- auto &result = state.stash.create<std::unique_ptr<Value>>(builder->build(std::move(builder)));
- const Value &result_ref = *(result.get());
- state.pop_pop_push(result_ref);
-};
-
-//-----------------------------------------------------------------------------
-
template <typename LCT, typename RCT, typename OCT, typename Fun, bool forward_lhs>
void my_mixed_dense_join_op(State &state, uint64_t param_in) {
const auto &param = unwrap_param<JoinParam>(param_in);
@@ -175,11 +136,6 @@ struct SelectGenericJoinOp {
if (param.sparse_plan.should_forward_rhs_index()) {
return my_mixed_dense_join_op<LCT,RCT,OCT,Fun,false>;
}
- if ((param.dense_plan.out_size == 1) &&
- (param.sparse_plan.sources.size() == param.sparse_plan.lhs_overlap.size()))
- {
- return my_sparse_full_overlap_join_op<LCT,RCT,OCT,Fun>;
- }
return my_mixed_join_op<LCT,RCT,OCT,Fun>;
}
};
diff --git a/eval/src/vespa/eval/instruction/generic_merge.cpp b/eval/src/vespa/eval/instruction/generic_merge.cpp
index 9b098db7763..434ab308c3c 100644
--- a/eval/src/vespa/eval/instruction/generic_merge.cpp
+++ b/eval/src/vespa/eval/instruction/generic_merge.cpp
@@ -1,9 +1,7 @@
// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-#include "detect_type.h"
#include "generic_merge.h"
#include <vespa/eval/eval/inline_operation.h>
-#include <vespa/eval/eval/fast_value.hpp>
#include <vespa/eval/eval/wrap_param.h>
#include <vespa/vespalib/util/stash.h>
#include <vespa/vespalib/util/typify.h>
diff --git a/eval/src/vespa/eval/instruction/sparse_full_overlap_join_function.cpp b/eval/src/vespa/eval/instruction/sparse_full_overlap_join_function.cpp
new file mode 100644
index 00000000000..480af3315b1
--- /dev/null
+++ b/eval/src/vespa/eval/instruction/sparse_full_overlap_join_function.cpp
@@ -0,0 +1,134 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "sparse_full_overlap_join_function.h"
+#include "generic_join.h"
+#include <vespa/eval/eval/fast_value.hpp>
+#include <vespa/vespalib/util/typify.h>
+
+namespace vespalib::eval {
+
+using namespace tensor_function;
+using namespace operation;
+using namespace instruction;
+
+namespace {
+
+template <typename CT, typename Fun, bool single_dim> // CT: cell type of both inputs and the output; Fun: join function; single_dim: enables the label-by-label loop
+const Value &my_fast_sparse_full_overlap_join(const FastAddrMap &lhs_map, const FastAddrMap &rhs_map, // iterates lhs_map and probes rhs_map; only addresses found in both produce an output cell
+ const CT *lhs_cells, const CT *rhs_cells, // cell arrays indexed by subspace index of the respective map
+ const JoinParam &param, Stash &stash) // the result is created in (and owned by) stash
+{
+ Fun fun(param.function);
+ auto &result = stash.create<FastValue<CT,true>>(param.res_type, lhs_map.addr_size(), 1, lhs_map.size()); // lhs_map.size() bounds the number of output cells (at most one per lhs entry); presumably used as a capacity hint -- confirm in FastValue
+ if constexpr (single_dim) { // compile-time branch: each address consists of one label
+ const auto &labels = lhs_map.labels();
+ for (size_t i = 0; i < labels.size(); ++i) { // i is also used as the lhs subspace index below
+ auto rhs_subspace = rhs_map.lookup_singledim(labels[i]);
+ if (rhs_subspace != FastAddrMap::npos()) { // npos: label not found in rhs, so no output for it
+ result.add_singledim_mapping(labels[i]);
+ auto cell_value = fun(lhs_cells[i], rhs_cells[rhs_subspace]);
+ result.my_cells.push_back_fast(cell_value); // NOTE(review): push_back_fast presumably relies on the capacity reserved at construction -- confirm
+ }
+ }
+ } else { // general case: addresses with several labels
+ lhs_map.each_map_entry([&](auto lhs_subspace, auto hash) { // hash supplied by the lhs map is reused for the rhs lookup and for the result mapping
+ auto lhs_addr = lhs_map.get_addr(lhs_subspace);
+ auto rhs_subspace = rhs_map.lookup(lhs_addr, hash);
+ if (rhs_subspace != FastAddrMap::npos()) { // npos: address not found in rhs, so no output for it
+ result.add_mapping(lhs_addr, hash);
+ auto cell_value = fun(lhs_cells[lhs_subspace], rhs_cells[rhs_subspace]);
+ result.my_cells.push_back_fast(cell_value); // NOTE(review): same capacity assumption as in the single_dim branch
+ }
+ });
+ }
+ return result;
+}
+
+template <typename CT, typename Fun, bool single_dim> // picks which operand to iterate over before calling my_fast_sparse_full_overlap_join
+const Value &my_fast_sparse_full_overlap_join_dispatch(const FastAddrMap &lhs_map, const FastAddrMap &rhs_map,
+ const CT *lhs_cells, const CT *rhs_cells,
+ const JoinParam &param, Stash &stash)
+{
+ return (rhs_map.size() < lhs_map.size()) // the callee iterates its first map and probes the second, so pass the smaller map first
+ ? my_fast_sparse_full_overlap_join<CT,SwapArgs2<Fun>,single_dim>(rhs_map, lhs_map, rhs_cells, lhs_cells, param, stash) // operands swapped; SwapArgs2 presumably swaps the call arguments back so fun still sees (lhs, rhs) -- confirm
+ : my_fast_sparse_full_overlap_join<CT,Fun,single_dim>(lhs_map, rhs_map, lhs_cells, rhs_cells, param, stash); // original order (also used when sizes are equal)
+}
+
+template <typename CT, typename Fun, bool single_dim> // interpreted-function op: joins two stack values and replaces them with the result
+void my_sparse_full_overlap_join_op(InterpretedFunction::State &state, uint64_t param_in) {
+ const auto &param = unwrap_param<JoinParam>(param_in); // param_in is the JoinParam wrapped by compile_self()
+ const Value &lhs = state.peek(1); // lhs is read from stack position 1 ...
+ const Value &rhs = state.peek(0); // ... and rhs from stack position 0
+ const auto &lhs_idx = lhs.index();
+ const auto &rhs_idx = rhs.index();
+ if (__builtin_expect(are_fast(lhs_idx, rhs_idx), true)) { // fast path hinted as the likely one; are_fast presumably checks that both indexes are FastValueIndex -- confirm
+ const Value &res = my_fast_sparse_full_overlap_join_dispatch<CT,Fun,single_dim>(as_fast(lhs_idx).map, as_fast(rhs_idx).map,
+ lhs.cells().typify<CT>().cbegin(), rhs.cells().typify<CT>().cbegin(), param, state.stash); // both inputs are read as CT cells; the result lives in state.stash
+ state.pop_pop_push(res);
+ } else { // other index implementations: fall back to the generic join
+ auto res = generic_mixed_join<CT,CT,CT,Fun>(lhs, rhs, param);
+ state.pop_pop_push(*state.stash.create<std::unique_ptr<Value>>(std::move(res))); // ownership moves into the stash; the stack receives a reference to the owned Value
+ }
+}
+
+struct SelectSparseFullOverlapJoinOp { // selector handed to typify_invoke in compile_self()
+ template <typename CT, typename Fun, typename SINGLE_DIM> // resolved cell type, join function type and a type carrying the single-dim flag in ::value
+ static auto invoke() { return my_sparse_full_overlap_join_op<CT,Fun,SINGLE_DIM::value>; } // returns the matching op instantiation
+};
+
+using MyTypify = TypifyValue<TypifyCellType,operation::TypifyOp2,TypifyBool>;
+
+bool is_sparse_like(const ValueType &type) { // true for types with at least one mapped dimension and a dense subspace of exactly one cell
+ return ((type.count_mapped_dimensions() > 0) && (type.dense_subspace_size() == 1)); // subspace size 1 admits trivial indexed dimensions (see the v2_a_trivial test parameter)
+}
+
+} // namespace <unnamed>
+
+SparseFullOverlapJoinFunction::SparseFullOverlapJoinFunction(const tensor_function::Join &original) // builds the optimized node from an existing Join
+ : tensor_function::Join(original.result_type(), // result type, children and join function are taken over unchanged
+ original.lhs(),
+ original.rhs(),
+ original.function())
+{
+ assert(compatible_types(result_type(), lhs().result_type(), rhs().result_type())); // callers must only construct this for types accepted by compatible_types()
+}
+
+InterpretedFunction::Instruction // builds the instruction that executes this join
+SparseFullOverlapJoinFunction::compile_self(const ValueBuilderFactory &factory, Stash &stash) const
+{
+ const auto &param = stash.create<JoinParam>(lhs().result_type(), rhs().result_type(), function(), factory); // the JoinParam is created in stash and handed to the op as a wrapped param
+ assert(param.res_type == result_type()); // the result type held by JoinParam must agree with this node's result type
+ bool single_dim = (result_type().count_mapped_dimensions() == 1); // selects the label-by-label loop in the op
+ auto op = typify_invoke<3,MyTypify,SelectSparseFullOverlapJoinOp>(result_type().cell_type(), function(), single_dim); // turns the three run-time values into template arguments of the op
+ return InterpretedFunction::Instruction(op, wrap_param<JoinParam>(param));
+}
+
+bool // returns true when a join with these result/operand types can use this optimization
+SparseFullOverlapJoinFunction::compatible_types(const ValueType &res, const ValueType &lhs, const ValueType &rhs)
+{
+ if ((lhs.cell_type() == rhs.cell_type()) && // the op is instantiated for a single cell type
+ is_sparse_like(lhs) && is_sparse_like(rhs) && // each operand: mapped dimensions present, dense subspace size 1
+ (res.count_mapped_dimensions() == lhs.count_mapped_dimensions()) && // NOTE(review): presumably the result carries the union of the operands' dimensions, so equal counts
+ (res.count_mapped_dimensions() == rhs.count_mapped_dimensions())) // on both sides mean both operands have the same mapped dimensions -- confirm
+ {
+ assert(is_sparse_like(res)); // sanity checks on the result type; active only when asserts are compiled in
+ assert(res.cell_type() == lhs.cell_type());
+ return true;
+ }
+ return false;
+}
+
+const TensorFunction & // returns the replacement node, or expr itself when the optimization does not apply
+SparseFullOverlapJoinFunction::optimize(const TensorFunction &expr, Stash &stash)
+{
+ if (auto join = as<Join>(expr)) { // only Join nodes are candidates
+ const TensorFunction &lhs = join->lhs();
+ const TensorFunction &rhs = join->rhs();
+ if (compatible_types(expr.result_type(), lhs.result_type(), rhs.result_type())) { // decision is based on types alone
+ return stash.create<SparseFullOverlapJoinFunction>(*join); // the new node is created in stash
+ }
+ }
+ return expr; // not a Join, or types not compatible: leave the expression untouched
+}
+
+} // namespace
diff --git a/eval/src/vespa/eval/instruction/sparse_full_overlap_join_function.h b/eval/src/vespa/eval/instruction/sparse_full_overlap_join_function.h
new file mode 100644
index 00000000000..13d35065997
--- /dev/null
+++ b/eval/src/vespa/eval/instruction/sparse_full_overlap_join_function.h
@@ -0,0 +1,22 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include <vespa/eval/eval/tensor_function.h>
+
+namespace vespalib::eval {
+
+/**
+ * Tensor function for joining tensors with full sparse overlap: both operands have mapped dimensions, a dense subspace size of 1 and the same cell type, and the result has as many mapped dimensions as each operand (see compatible_types).
+ */
+class SparseFullOverlapJoinFunction : public tensor_function::Join
+{
+public:
+ SparseFullOverlapJoinFunction(const tensor_function::Join &original); // takes over result type, children and function of an existing Join
+ InterpretedFunction::Instruction compile_self(const ValueBuilderFactory &factory, Stash &stash) const override; // produces the specialized join instruction
+ bool result_is_mutable() const override { return true; } // always reports the result as mutable
+ static bool compatible_types(const ValueType &res, const ValueType &lhs, const ValueType &rhs); // true when the optimization applies to these types
+ static const TensorFunction &optimize(const TensorFunction &expr, Stash &stash); // returns a replacement created in stash, or expr itself when not applicable
+};
+
+} // namespace