diff options
author | Håvard Pettersen <havardpe@oath.com> | 2021-01-14 15:28:22 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2021-01-14 15:33:53 +0000 |
commit | 972e6b8052da175c5db224f19c1480b6238447f6 (patch) | |
tree | 252c1d038bd1ec5d82d250fd1e3c885e34c3482a /eval | |
parent | c1084399767a5aaf504229ac9c485c696a0832d3 (diff) |
forward index when joining a mixed tensor with a dense one
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/tests/instruction/generic_join/generic_join_test.cpp | 4 | ||||
-rw-r--r-- | eval/src/vespa/eval/instruction/generic_join.cpp | 59 | ||||
-rw-r--r-- | eval/src/vespa/eval/instruction/generic_join.h | 2 |
3 files changed, 64 insertions, 1 deletions
diff --git a/eval/src/tests/instruction/generic_join/generic_join_test.cpp b/eval/src/tests/instruction/generic_join/generic_join_test.cpp index f4046b3d059..7083f4d7eb7 100644 --- a/eval/src/tests/instruction/generic_join/generic_join_test.cpp +++ b/eval/src/tests/instruction/generic_join/generic_join_test.cpp @@ -40,7 +40,9 @@ std::vector<Layout> join_layouts = { {x({"a","b","c"}),y(5)}, {y(5),z({"i","j","k","l"})}, float_cells({x({"a","b","c"}),y(5)}), {y(5),z({"i","j","k","l"})}, {x({"a","b","c"}),y(5)}, float_cells({y(5),z({"i","j","k","l"})}), - float_cells({x({"a","b","c"}),y(5)}), float_cells({y(5),z({"i","j","k","l"})}) + float_cells({x({"a","b","c"}),y(5)}), float_cells({y(5),z({"i","j","k","l"})}), + {x({"a","b","c"}),y(5)}, float_cells({y(5)}), + {y(5)}, float_cells({x({"a","b","c"}),y(5)}) }; bool join_address(const TensorSpec::Address &a, const TensorSpec::Address &b, TensorSpec::Address &addr) { diff --git a/eval/src/vespa/eval/instruction/generic_join.cpp b/eval/src/vespa/eval/instruction/generic_join.cpp index e0dc0feea28..abe29b8228c 100644 --- a/eval/src/vespa/eval/instruction/generic_join.cpp +++ b/eval/src/vespa/eval/instruction/generic_join.cpp @@ -148,6 +148,37 @@ void my_sparse_full_overlap_join_op(State &state, uint64_t param_in) { //----------------------------------------------------------------------------- +template <typename LCT, typename RCT, typename OCT, typename Fun, bool forward_lhs> +void my_mixed_dense_join_op(State &state, uint64_t param_in) { + const auto ¶m = unwrap_param<JoinParam>(param_in); + Fun fun(param.function); + auto lhs_cells = state.peek(1).cells().typify<LCT>(); + auto rhs_cells = state.peek(0).cells().typify<RCT>(); + const auto &index = state.peek(forward_lhs ? 1 : 0).index(); + size_t num_subspaces = index.size(); + ArrayRef<OCT> out_cells = state.stash.create_uninitialized_array<OCT>(param.dense_plan.out_size * num_subspaces); + OCT *dst = out_cells.begin(); + const LCT *lhs = lhs_cells.begin(); + const RCT *rhs = rhs_cells.begin(); + auto join_cells = [&](size_t lhs_idx, size_t rhs_idx) { *dst++ = fun(lhs[lhs_idx], rhs[rhs_idx]); }; + for (size_t i = 0; i < num_subspaces; ++i) { + param.dense_plan.execute(0, 0, join_cells); + if (forward_lhs) { + lhs += param.dense_plan.lhs_size; + } else { + rhs += param.dense_plan.rhs_size; + } + } + if (forward_lhs) { + assert(lhs == lhs_cells.end()); + } else { + assert(rhs == rhs_cells.end()); + } + state.pop_pop_push(state.stash.create<ValueView>(param.res_type, index, TypedCells(out_cells))); +}; + +//----------------------------------------------------------------------------- + template <typename LCT, typename RCT, typename OCT, typename Fun> void my_dense_join_op(State &state, uint64_t param_in) { const auto ¶m = unwrap_param<JoinParam>(param_in); @@ -180,6 +211,12 @@ struct SelectGenericJoinOp { if (param.sparse_plan.sources.empty()) { return my_dense_join_op<LCT,RCT,OCT,Fun>; } + if (param.sparse_plan.should_forward_lhs_index()) { + return my_mixed_dense_join_op<LCT,RCT,OCT,Fun,true>; + } + if (param.sparse_plan.should_forward_rhs_index()) { + return my_mixed_dense_join_op<LCT,RCT,OCT,Fun,false>; + } if ((param.dense_plan.out_size == 1) && (param.sparse_plan.sources.size() == param.sparse_plan.lhs_overlap.size())) { @@ -271,6 +308,28 @@ SparseJoinPlan::SparseJoinPlan(const ValueType &lhs_type, const ValueType &rhs_t [](const auto &a, const auto &b){ return (a.name < b.name); }); } +bool +SparseJoinPlan::should_forward_lhs_index() const +{ + for (Source src: sources) { + if (src != Source::LHS) { + return false; + } + } + return (sources.size() > 0); +} + +bool +SparseJoinPlan::should_forward_rhs_index() const +{ + for (Source src: sources) { + if (src != Source::RHS) { + return false; + } + } + return (sources.size() > 0); +} + SparseJoinPlan::~SparseJoinPlan() = default; //----------------------------------------------------------------------------- diff --git a/eval/src/vespa/eval/instruction/generic_join.h b/eval/src/vespa/eval/instruction/generic_join.h index aaee64f7b25..1fcfcf416cc 100644 --- a/eval/src/vespa/eval/instruction/generic_join.h +++ b/eval/src/vespa/eval/instruction/generic_join.h @@ -55,6 +55,8 @@ struct SparseJoinPlan { std::vector<Source> sources; std::vector<size_t> lhs_overlap; std::vector<size_t> rhs_overlap; + bool should_forward_lhs_index() const; + bool should_forward_rhs_index() const; SparseJoinPlan(const ValueType &lhs_type, const ValueType &rhs_type); ~SparseJoinPlan(); }; |