diff options
11 files changed, 266 insertions, 150 deletions
diff --git a/eval/CMakeLists.txt b/eval/CMakeLists.txt index 22750a27f20..66270bb323b 100644 --- a/eval/CMakeLists.txt +++ b/eval/CMakeLists.txt @@ -66,6 +66,7 @@ vespa_define_module( src/tests/instruction/pow_as_map_optimizer src/tests/instruction/remove_trivial_dimension_optimizer src/tests/instruction/sparse_dot_product_function + src/tests/instruction/sparse_full_overlap_join_function src/tests/instruction/sparse_merge_function src/tests/instruction/sparse_no_overlap_join_function src/tests/instruction/sum_max_dot_product_function diff --git a/eval/src/tests/instruction/sparse_full_overlap_join_function/CMakeLists.txt b/eval/src/tests/instruction/sparse_full_overlap_join_function/CMakeLists.txt new file mode 100644 index 00000000000..54841140278 --- /dev/null +++ b/eval/src/tests/instruction/sparse_full_overlap_join_function/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(eval_sparse_full_overlap_join_function_test_app TEST + SOURCES + sparse_full_overlap_join_function_test.cpp + DEPENDS + vespaeval + GTest::GTest +) +vespa_add_test(NAME eval_sparse_full_overlap_join_function_test_app COMMAND eval_sparse_full_overlap_join_function_test_app) diff --git a/eval/src/tests/instruction/sparse_full_overlap_join_function/sparse_full_overlap_join_function_test.cpp b/eval/src/tests/instruction/sparse_full_overlap_join_function/sparse_full_overlap_join_function_test.cpp new file mode 100644 index 00000000000..e3001b17602 --- /dev/null +++ b/eval/src/tests/instruction/sparse_full_overlap_join_function/sparse_full_overlap_join_function_test.cpp @@ -0,0 +1,92 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+ +#include <vespa/eval/eval/fast_value.h> +#include <vespa/eval/eval/simple_value.h> +#include <vespa/eval/instruction/sparse_full_overlap_join_function.h> +#include <vespa/eval/eval/test/eval_fixture.h> +#include <vespa/vespalib/gtest/gtest.h> + +using namespace vespalib::eval; +using namespace vespalib::eval::test; + +const ValueBuilderFactory &prod_factory = FastValueBuilderFactory::get(); +const ValueBuilderFactory &test_factory = SimpleValueBuilderFactory::get(); + +//----------------------------------------------------------------------------- + +EvalFixture::ParamRepo make_params() { + return EvalFixture::ParamRepo() + .add_variants("v1_a", GenSpec(3.0).map("a", 8, 1)) + .add_variants("v2_a", GenSpec(7.0).map("a", 4, 2)) + .add_variants("v2_a_trivial", GenSpec(7.0).map("a", 4, 2).idx("b", 1).idx("c", 1)) + .add_variants("v3_b", GenSpec(5.0).map("b", 4, 2)) + .add("m1_ab", GenSpec(3.0).map("a", 8, 1).map("b", 8, 1)) + .add("m2_ab", GenSpec(17.0).map("a", 4, 2).map("b", 4, 2)) + .add("m3_bc", GenSpec(11.0).map("b", 4, 2).map("c", 4, 2)) + .add("scalar", GenSpec(1.0)) + .add("dense_a", GenSpec().idx("a", 5)) + .add("mixed_ab", GenSpec().map("a", 5, 1).idx("b", 5)); +} +EvalFixture::ParamRepo param_repo = make_params(); + +void assert_optimized(const vespalib::string &expr) { + EvalFixture fast_fixture(prod_factory, expr, param_repo, true); + EvalFixture test_fixture(test_factory, expr, param_repo, true); + EvalFixture slow_fixture(prod_factory, expr, param_repo, false); + EXPECT_EQ(fast_fixture.result(), EvalFixture::ref(expr, param_repo)); + EXPECT_EQ(test_fixture.result(), EvalFixture::ref(expr, param_repo)); + EXPECT_EQ(slow_fixture.result(), EvalFixture::ref(expr, param_repo)); + EXPECT_EQ(fast_fixture.find_all<SparseFullOverlapJoinFunction>().size(), 1u); + EXPECT_EQ(test_fixture.find_all<SparseFullOverlapJoinFunction>().size(), 1u); + EXPECT_EQ(slow_fixture.find_all<SparseFullOverlapJoinFunction>().size(), 0u); +} + +void assert_not_optimized(const 
vespalib::string &expr) { + EvalFixture fast_fixture(prod_factory, expr, param_repo, true); + EXPECT_EQ(fast_fixture.result(), EvalFixture::ref(expr, param_repo)); + EXPECT_EQ(fast_fixture.find_all<SparseFullOverlapJoinFunction>().size(), 0u); +} + +//----------------------------------------------------------------------------- + +TEST(SparseFullOverlapJoin, expression_can_be_optimized) +{ + assert_optimized("v1_a-v2_a"); + assert_optimized("v2_a-v1_a"); + assert_optimized("join(v1_a,v2_a,f(x,y)(max(x,y)))"); +} + +TEST(SparseFullOverlapJoin, multi_dimensional_expression_can_be_optimized) +{ + assert_optimized("m1_ab-m2_ab"); + assert_optimized("m2_ab-m1_ab"); + assert_optimized("join(m1_ab,m2_ab,f(x,y)(max(x,y)))"); +} + +TEST(SparseFullOverlapJoin, trivial_dimensions_are_ignored) +{ + assert_optimized("v1_a*v2_a_trivial"); + assert_optimized("v2_a_trivial*v1_a"); +} + +TEST(SparseFullOverlapJoin, inappropriate_shapes_are_not_optimized) +{ + assert_not_optimized("v1_a*scalar"); + assert_not_optimized("v1_a*mixed_ab"); + assert_not_optimized("v1_a*v3_b"); + assert_not_optimized("v1_a*m1_ab"); + assert_not_optimized("m1_ab*m3_bc"); + assert_not_optimized("scalar*scalar"); + assert_not_optimized("dense_a*dense_a"); + assert_not_optimized("mixed_ab*mixed_ab"); +} + +TEST(SparseFullOverlapJoin, mixed_cell_types_are_not_optimized) +{ + assert_not_optimized("v1_a*v2_a_f"); + assert_not_optimized("v1_a_f*v2_a"); +} + +//----------------------------------------------------------------------------- + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/eval/src/vespa/eval/eval/fast_value.hpp b/eval/src/vespa/eval/eval/fast_value.hpp index d5cfc9c6368..33624fb920e 100644 --- a/eval/src/vespa/eval/eval/fast_value.hpp +++ b/eval/src/vespa/eval/eval/fast_value.hpp @@ -47,7 +47,7 @@ struct FastFilterView : public Value::Index::View { const FastAddrMap &map; std::vector<size_t> match_dims; std::vector<size_t> extract_dims; - std::vector<string_id> query; + std::vector<string_id> query; size_t
pos; bool is_match(ConstArrayRef<string_id> addr) const { @@ -141,12 +141,6 @@ struct FastValueIndex final : Value::Index { FastAddrMap map; FastValueIndex(size_t num_mapped_dims_in, const std::vector<string_id> &labels, size_t expected_subspaces_in) : map(num_mapped_dims_in, labels, expected_subspaces_in) {} - - template <typename LCT, typename RCT, typename OCT, typename Fun> - static const Value &sparse_full_overlap_join(const ValueType &res_type, const Fun &fun, - const FastValueIndex &lhs, const FastValueIndex &rhs, - ConstArrayRef<LCT> lhs_cells, ConstArrayRef<RCT> rhs_cells, Stash &stash); - size_t size() const override { return map.size(); } std::unique_ptr<View> create_view(const std::vector<size_t> &dims) const override; }; @@ -267,6 +261,10 @@ struct FastValue final : Value, ValueBuilder<T> { } my_index.map.add_mapping(hash); } + void add_singledim_mapping(string_id label) { + my_handles.push_back(label); + my_index.map.add_mapping(FastAddrMap::hash_label(label)); + } ArrayRef<T> add_subspace(ConstArrayRef<vespalib::stringref> addr) override { add_mapping(addr); return my_cells.add_cells(my_subspace_size); @@ -344,25 +342,4 @@ struct FastScalarBuilder final : ValueBuilder<T> { //----------------------------------------------------------------------------- -template <typename LCT, typename RCT, typename OCT, typename Fun> -const Value & -FastValueIndex::sparse_full_overlap_join(const ValueType &res_type, const Fun &fun, - const FastValueIndex &lhs, const FastValueIndex &rhs, - ConstArrayRef<LCT> lhs_cells, ConstArrayRef<RCT> rhs_cells, Stash &stash) -{ - auto &result = stash.create<FastValue<OCT,true>>(res_type, lhs.map.addr_size(), 1, lhs.map.size()); - lhs.map.each_map_entry([&](auto lhs_subspace, auto hash) { - auto lhs_addr = lhs.map.get_addr(lhs_subspace); - auto rhs_subspace = rhs.map.lookup(lhs_addr, hash); - if (rhs_subspace != FastAddrMap::npos()) { - result.add_mapping(lhs_addr, hash); - auto cell_value = fun(lhs_cells[lhs_subspace], 
rhs_cells[rhs_subspace]); - result.my_cells.push_back_fast(cell_value); - } - }); - return result; -} - -//----------------------------------------------------------------------------- - } diff --git a/eval/src/vespa/eval/eval/optimize_tensor_function.cpp b/eval/src/vespa/eval/eval/optimize_tensor_function.cpp index aef49a2c75b..f1ce293b18c 100644 --- a/eval/src/vespa/eval/eval/optimize_tensor_function.cpp +++ b/eval/src/vespa/eval/eval/optimize_tensor_function.cpp @@ -8,6 +8,7 @@ #include <vespa/eval/instruction/sparse_dot_product_function.h> #include <vespa/eval/instruction/sparse_merge_function.h> #include <vespa/eval/instruction/sparse_no_overlap_join_function.h> +#include <vespa/eval/instruction/sparse_full_overlap_join_function.h> #include <vespa/eval/instruction/mixed_inner_product_function.h> #include <vespa/eval/instruction/sum_max_dot_product_function.h> #include <vespa/eval/instruction/dense_xw_product_function.h> @@ -76,6 +77,7 @@ const TensorFunction &optimize_for_factory(const ValueBuilderFactory &, const Te child.set(DenseSingleReduceFunction::optimize(child.get(), stash)); child.set(SparseMergeFunction::optimize(child.get(), stash)); child.set(SparseNoOverlapJoinFunction::optimize(child.get(), stash)); + child.set(SparseFullOverlapJoinFunction::optimize(child.get(), stash)); nodes.pop_back(); } } diff --git a/eval/src/vespa/eval/instruction/CMakeLists.txt b/eval/src/vespa/eval/instruction/CMakeLists.txt index 50f7dbe7005..97838f2adb9 100644 --- a/eval/src/vespa/eval/instruction/CMakeLists.txt +++ b/eval/src/vespa/eval/instruction/CMakeLists.txt @@ -33,6 +33,7 @@ vespa_add_library(eval_instruction OBJECT remove_trivial_dimension_optimizer.cpp replace_type_function.cpp sparse_dot_product_function.cpp + sparse_full_overlap_join_function.cpp sparse_merge_function.cpp sparse_no_overlap_join_function.cpp sum_max_dot_product_function.cpp diff --git a/eval/src/vespa/eval/instruction/detect_type.h b/eval/src/vespa/eval/instruction/detect_type.h deleted file 
mode 100644 index f1769fa15cc..00000000000 --- a/eval/src/vespa/eval/instruction/detect_type.h +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -#include <typeindex> -#include <array> -#include <cstddef> - -#pragma once - -namespace vespalib::eval::instruction { - -/* - * Utilities for detecting implementation class by comparing - * typeindex(typeid(T)); for now these are local to this - * namespace, but we can consider moving them to a more - * common place (probably vespalib) if we see more use-cases. - */ - -/** - * Recognize a (const) instance of type T. This is cheaper than - * dynamic_cast, but requires the object to be exactly of class T. - * Returns a pointer to the object as T if recognized, nullptr - * otherwise. - **/ -template<typename T, typename U> -const T * -recognize_by_type_index(const U & object) -{ - if (std::type_index(typeid(object)) == std::type_index(typeid(T))) { - return static_cast<const T *>(&object); - } - return nullptr; -} - -/** - * Packs N recognized values into one object, used as return value - * from detect_type<T>. - * - * Use all_converted() or the equivalent bool cast operator to check - * if all objects were recognized. After this check is successful use - * get<0>(), get<1>() etc to get a reference to the objects. 
- **/ -template<typename T, size_t N> -class RecognizedValues -{ -private: - std::array<const T *, N> _pointers; -public: - RecognizedValues(std::array<const T *, N> && pointers) - : _pointers(std::move(pointers)) - {} - bool all_converted() const { - for (auto p : _pointers) { - if (p == nullptr) return false; - } - return true; - } - operator bool() const { return all_converted(); } - template<size_t idx> const T& get() const { - static_assert(idx < N); - return *_pointers[idx]; - } -}; - -/** - * For all arguments, detect if they have typeid(T), convert to T if - * possible, and return a RecognizedValues packing the converted - * values. - **/ -template<typename T, typename... Args> -RecognizedValues<T, sizeof...(Args)> -detect_type(const Args &... args) -{ - return RecognizedValues<T, sizeof...(Args)>({(recognize_by_type_index<T>(args))...}); -} - -} // namespace diff --git a/eval/src/vespa/eval/instruction/generic_join.cpp b/eval/src/vespa/eval/instruction/generic_join.cpp index 4b3755509c7..fb714bcf16e 100644 --- a/eval/src/vespa/eval/instruction/generic_join.cpp +++ b/eval/src/vespa/eval/instruction/generic_join.cpp @@ -1,9 +1,7 @@ // Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
#include "generic_join.h" -#include "detect_type.h" #include <vespa/eval/eval/inline_operation.h> -#include <vespa/eval/eval/fast_value.hpp> #include <vespa/eval/eval/wrap_param.h> #include <vespa/vespalib/util/overload.h> #include <vespa/vespalib/util/stash.h> @@ -69,43 +67,6 @@ void my_mixed_join_op(State &state, uint64_t param_in) { //----------------------------------------------------------------------------- -template <typename LCT, typename RCT, typename OCT, typename Fun> -void my_sparse_full_overlap_join_op(State &state, uint64_t param_in) { - const auto &param = unwrap_param<JoinParam>(param_in); - const Value &lhs = state.peek(1); - const Value &rhs = state.peek(0); - auto lhs_cells = lhs.cells().typify<LCT>(); - auto rhs_cells = rhs.cells().typify<RCT>(); - const Value::Index &lhs_index = lhs.index(); - const Value::Index &rhs_index = rhs.index(); - if (auto indexes = detect_type<FastValueIndex>(lhs_index, rhs_index)) { - const auto &lhs_fast = indexes.get<0>(); - const auto &rhs_fast = indexes.get<1>(); - return (rhs_fast.map.size() < lhs_fast.map.size()) - ?
state.pop_pop_push(FastValueIndex::sparse_full_overlap_join<RCT,LCT,OCT,SwapArgs2<Fun>> - (param.res_type, SwapArgs2<Fun>(param.function), rhs_fast, lhs_fast, rhs_cells, lhs_cells, state.stash)) - : state.pop_pop_push(FastValueIndex::sparse_full_overlap_join<LCT,RCT,OCT,Fun> - (param.res_type, Fun(param.function), lhs_fast, rhs_fast, lhs_cells, rhs_cells, state.stash)); - } - Fun fun(param.function); - SparseJoinState sparse(param.sparse_plan, lhs_index, rhs_index); - auto builder = param.factory.create_transient_value_builder<OCT>(param.res_type, param.sparse_plan.sources.size(), param.dense_plan.out_size, sparse.first_index.size()); - auto outer = sparse.first_index.create_view({}); - auto inner = sparse.second_index.create_view(sparse.second_view_dims); - outer->lookup({}); - while (outer->next_result(sparse.first_address, sparse.first_subspace)) { - inner->lookup(sparse.address_overlap); - if (inner->next_result(sparse.second_only_address, sparse.second_subspace)) { - builder->add_subspace(sparse.full_address)[0] = fun(lhs_cells[sparse.lhs_subspace], rhs_cells[sparse.rhs_subspace]); - } - } - auto &result = state.stash.create<std::unique_ptr<Value>>(builder->build(std::move(builder))); - const Value &result_ref = *(result.get()); - state.pop_pop_push(result_ref); -}; - -//----------------------------------------------------------------------------- - template <typename LCT, typename RCT, typename OCT, typename Fun, bool forward_lhs> void my_mixed_dense_join_op(State &state, uint64_t param_in) { const auto &param = unwrap_param<JoinParam>(param_in); @@ -175,11 +136,6 @@ struct SelectGenericJoinOp { if (param.sparse_plan.should_forward_rhs_index()) { return my_mixed_dense_join_op<LCT,RCT,OCT,Fun,false>; } - if ((param.dense_plan.out_size == 1) && - (param.sparse_plan.sources.size() == param.sparse_plan.lhs_overlap.size())) - { - return my_sparse_full_overlap_join_op<LCT,RCT,OCT,Fun>; - } return my_mixed_join_op<LCT,RCT,OCT,Fun>; } }; diff --git
a/eval/src/vespa/eval/instruction/generic_merge.cpp b/eval/src/vespa/eval/instruction/generic_merge.cpp index 9b098db7763..434ab308c3c 100644 --- a/eval/src/vespa/eval/instruction/generic_merge.cpp +++ b/eval/src/vespa/eval/instruction/generic_merge.cpp @@ -1,9 +1,7 @@ // Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -#include "detect_type.h" #include "generic_merge.h" #include <vespa/eval/eval/inline_operation.h> -#include <vespa/eval/eval/fast_value.hpp> #include <vespa/eval/eval/wrap_param.h> #include <vespa/vespalib/util/stash.h> #include <vespa/vespalib/util/typify.h> diff --git a/eval/src/vespa/eval/instruction/sparse_full_overlap_join_function.cpp b/eval/src/vespa/eval/instruction/sparse_full_overlap_join_function.cpp new file mode 100644 index 00000000000..480af3315b1 --- /dev/null +++ b/eval/src/vespa/eval/instruction/sparse_full_overlap_join_function.cpp @@ -0,0 +1,134 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+ +#include "sparse_full_overlap_join_function.h" +#include "generic_join.h" +#include <vespa/eval/eval/fast_value.hpp> +#include <vespa/vespalib/util/typify.h> + +namespace vespalib::eval { + +using namespace tensor_function; +using namespace operation; +using namespace instruction; + +namespace { + +template <typename CT, typename Fun, bool single_dim> +const Value &my_fast_sparse_full_overlap_join(const FastAddrMap &lhs_map, const FastAddrMap &rhs_map, + const CT *lhs_cells, const CT *rhs_cells, + const JoinParam &param, Stash &stash) +{ + Fun fun(param.function); + auto &result = stash.create<FastValue<CT,true>>(param.res_type, lhs_map.addr_size(), 1, lhs_map.size()); + if constexpr (single_dim) { + const auto &labels = lhs_map.labels(); + for (size_t i = 0; i < labels.size(); ++i) { + auto rhs_subspace = rhs_map.lookup_singledim(labels[i]); + if (rhs_subspace != FastAddrMap::npos()) { + result.add_singledim_mapping(labels[i]); + auto cell_value = fun(lhs_cells[i], rhs_cells[rhs_subspace]); + result.my_cells.push_back_fast(cell_value); + } + } + } else { + lhs_map.each_map_entry([&](auto lhs_subspace, auto hash) { + auto lhs_addr = lhs_map.get_addr(lhs_subspace); + auto rhs_subspace = rhs_map.lookup(lhs_addr, hash); + if (rhs_subspace != FastAddrMap::npos()) { + result.add_mapping(lhs_addr, hash); + auto cell_value = fun(lhs_cells[lhs_subspace], rhs_cells[rhs_subspace]); + result.my_cells.push_back_fast(cell_value); + } + }); + } + return result; +} + +template <typename CT, typename Fun, bool single_dim> +const Value &my_fast_sparse_full_overlap_join_dispatch(const FastAddrMap &lhs_map, const FastAddrMap &rhs_map, + const CT *lhs_cells, const CT *rhs_cells, + const JoinParam &param, Stash &stash) +{ + return (rhs_map.size() < lhs_map.size()) + ?
my_fast_sparse_full_overlap_join<CT,SwapArgs2<Fun>,single_dim>(rhs_map, lhs_map, rhs_cells, lhs_cells, param, stash) + : my_fast_sparse_full_overlap_join<CT,Fun,single_dim>(lhs_map, rhs_map, lhs_cells, rhs_cells, param, stash); +} + +template <typename CT, typename Fun, bool single_dim> +void my_sparse_full_overlap_join_op(InterpretedFunction::State &state, uint64_t param_in) { + const auto &param = unwrap_param<JoinParam>(param_in); + const Value &lhs = state.peek(1); + const Value &rhs = state.peek(0); + const auto &lhs_idx = lhs.index(); + const auto &rhs_idx = rhs.index(); + if (__builtin_expect(are_fast(lhs_idx, rhs_idx), true)) { + const Value &res = my_fast_sparse_full_overlap_join_dispatch<CT,Fun,single_dim>(as_fast(lhs_idx).map, as_fast(rhs_idx).map, + lhs.cells().typify<CT>().cbegin(), rhs.cells().typify<CT>().cbegin(), param, state.stash); + state.pop_pop_push(res); + } else { + auto res = generic_mixed_join<CT,CT,CT,Fun>(lhs, rhs, param); + state.pop_pop_push(*state.stash.create<std::unique_ptr<Value>>(std::move(res))); + } +} + +struct SelectSparseFullOverlapJoinOp { + template <typename CT, typename Fun, typename SINGLE_DIM> + static auto invoke() { return my_sparse_full_overlap_join_op<CT,Fun,SINGLE_DIM::value>; } +}; + +using MyTypify = TypifyValue<TypifyCellType,operation::TypifyOp2,TypifyBool>; + +bool is_sparse_like(const ValueType &type) { + return ((type.count_mapped_dimensions() > 0) && (type.dense_subspace_size() == 1)); +} + +} // namespace <unnamed> + +SparseFullOverlapJoinFunction::SparseFullOverlapJoinFunction(const tensor_function::Join &original) + : tensor_function::Join(original.result_type(), + original.lhs(), + original.rhs(), + original.function()) +{ + assert(compatible_types(result_type(), lhs().result_type(), rhs().result_type())); +} + +InterpretedFunction::Instruction +SparseFullOverlapJoinFunction::compile_self(const ValueBuilderFactory &factory, Stash &stash) const +{ + const auto &param =
stash.create<JoinParam>(lhs().result_type(), rhs().result_type(), function(), factory); + assert(param.res_type == result_type()); + bool single_dim = (result_type().count_mapped_dimensions() == 1); + auto op = typify_invoke<3,MyTypify,SelectSparseFullOverlapJoinOp>(result_type().cell_type(), function(), single_dim); + return InterpretedFunction::Instruction(op, wrap_param<JoinParam>(param)); +} + +bool +SparseFullOverlapJoinFunction::compatible_types(const ValueType &res, const ValueType &lhs, const ValueType &rhs) +{ + if ((lhs.cell_type() == rhs.cell_type()) && + is_sparse_like(lhs) && is_sparse_like(rhs) && + (res.count_mapped_dimensions() == lhs.count_mapped_dimensions()) && + (res.count_mapped_dimensions() == rhs.count_mapped_dimensions())) + { + assert(is_sparse_like(res)); + assert(res.cell_type() == lhs.cell_type()); + return true; + } + return false; +} + +const TensorFunction & +SparseFullOverlapJoinFunction::optimize(const TensorFunction &expr, Stash &stash) +{ + if (auto join = as<Join>(expr)) { + const TensorFunction &lhs = join->lhs(); + const TensorFunction &rhs = join->rhs(); + if (compatible_types(expr.result_type(), lhs.result_type(), rhs.result_type())) { + return stash.create<SparseFullOverlapJoinFunction>(*join); + } + } + return expr; +} + +} // namespace diff --git a/eval/src/vespa/eval/instruction/sparse_full_overlap_join_function.h b/eval/src/vespa/eval/instruction/sparse_full_overlap_join_function.h new file mode 100644 index 00000000000..13d35065997 --- /dev/null +++ b/eval/src/vespa/eval/instruction/sparse_full_overlap_join_function.h @@ -0,0 +1,22 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include <vespa/eval/eval/tensor_function.h> + +namespace vespalib::eval { + +/** + * Tensor function for joining tensors with full sparse overlap. 
+ */ +class SparseFullOverlapJoinFunction : public tensor_function::Join +{ +public: + SparseFullOverlapJoinFunction(const tensor_function::Join &original); + InterpretedFunction::Instruction compile_self(const ValueBuilderFactory &factory, Stash &stash) const override; + bool result_is_mutable() const override { return true; } + static bool compatible_types(const ValueType &res, const ValueType &lhs, const ValueType &rhs); + static const TensorFunction &optimize(const TensorFunction &expr, Stash &stash); +}; + +} // namespace |