summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2021-01-14 15:28:22 +0000
committerHåvard Pettersen <havardpe@oath.com>2021-01-14 15:33:53 +0000
commit972e6b8052da175c5db224f19c1480b6238447f6 (patch)
tree252c1d038bd1ec5d82d250fd1e3c885e34c3482a /eval
parentc1084399767a5aaf504229ac9c485c696a0832d3 (diff)
forward index when joining a mixed tensor with a dense one
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/instruction/generic_join/generic_join_test.cpp4
-rw-r--r--eval/src/vespa/eval/instruction/generic_join.cpp59
-rw-r--r--eval/src/vespa/eval/instruction/generic_join.h2
3 files changed, 64 insertions, 1 deletions
diff --git a/eval/src/tests/instruction/generic_join/generic_join_test.cpp b/eval/src/tests/instruction/generic_join/generic_join_test.cpp
index f4046b3d059..7083f4d7eb7 100644
--- a/eval/src/tests/instruction/generic_join/generic_join_test.cpp
+++ b/eval/src/tests/instruction/generic_join/generic_join_test.cpp
@@ -40,7 +40,9 @@ std::vector<Layout> join_layouts = {
{x({"a","b","c"}),y(5)}, {y(5),z({"i","j","k","l"})},
float_cells({x({"a","b","c"}),y(5)}), {y(5),z({"i","j","k","l"})},
{x({"a","b","c"}),y(5)}, float_cells({y(5),z({"i","j","k","l"})}),
- float_cells({x({"a","b","c"}),y(5)}), float_cells({y(5),z({"i","j","k","l"})})
+ float_cells({x({"a","b","c"}),y(5)}), float_cells({y(5),z({"i","j","k","l"})}),
+ {x({"a","b","c"}),y(5)}, float_cells({y(5)}),
+ {y(5)}, float_cells({x({"a","b","c"}),y(5)})
};
bool join_address(const TensorSpec::Address &a, const TensorSpec::Address &b, TensorSpec::Address &addr) {
diff --git a/eval/src/vespa/eval/instruction/generic_join.cpp b/eval/src/vespa/eval/instruction/generic_join.cpp
index e0dc0feea28..abe29b8228c 100644
--- a/eval/src/vespa/eval/instruction/generic_join.cpp
+++ b/eval/src/vespa/eval/instruction/generic_join.cpp
@@ -148,6 +148,37 @@ void my_sparse_full_overlap_join_op(State &state, uint64_t param_in) {
//-----------------------------------------------------------------------------
+template <typename LCT, typename RCT, typename OCT, typename Fun, bool forward_lhs>
+void my_mixed_dense_join_op(State &state, uint64_t param_in) {
+ const auto &param = unwrap_param<JoinParam>(param_in);
+ Fun fun(param.function);
+ auto lhs_cells = state.peek(1).cells().typify<LCT>();
+ auto rhs_cells = state.peek(0).cells().typify<RCT>();
+ const auto &index = state.peek(forward_lhs ? 1 : 0).index();
+ size_t num_subspaces = index.size();
+ ArrayRef<OCT> out_cells = state.stash.create_uninitialized_array<OCT>(param.dense_plan.out_size * num_subspaces);
+ OCT *dst = out_cells.begin();
+ const LCT *lhs = lhs_cells.begin();
+ const RCT *rhs = rhs_cells.begin();
+ auto join_cells = [&](size_t lhs_idx, size_t rhs_idx) { *dst++ = fun(lhs[lhs_idx], rhs[rhs_idx]); };
+ for (size_t i = 0; i < num_subspaces; ++i) {
+ param.dense_plan.execute(0, 0, join_cells);
+ if (forward_lhs) {
+ lhs += param.dense_plan.lhs_size;
+ } else {
+ rhs += param.dense_plan.rhs_size;
+ }
+ }
+ if (forward_lhs) {
+ assert(lhs == lhs_cells.end());
+ } else {
+ assert(rhs == rhs_cells.end());
+ }
+ state.pop_pop_push(state.stash.create<ValueView>(param.res_type, index, TypedCells(out_cells)));
+};
+
+//-----------------------------------------------------------------------------
+
template <typename LCT, typename RCT, typename OCT, typename Fun>
void my_dense_join_op(State &state, uint64_t param_in) {
const auto &param = unwrap_param<JoinParam>(param_in);
@@ -180,6 +211,12 @@ struct SelectGenericJoinOp {
if (param.sparse_plan.sources.empty()) {
return my_dense_join_op<LCT,RCT,OCT,Fun>;
}
+ if (param.sparse_plan.should_forward_lhs_index()) {
+ return my_mixed_dense_join_op<LCT,RCT,OCT,Fun,true>;
+ }
+ if (param.sparse_plan.should_forward_rhs_index()) {
+ return my_mixed_dense_join_op<LCT,RCT,OCT,Fun,false>;
+ }
if ((param.dense_plan.out_size == 1) &&
(param.sparse_plan.sources.size() == param.sparse_plan.lhs_overlap.size()))
{
@@ -271,6 +308,28 @@ SparseJoinPlan::SparseJoinPlan(const ValueType &lhs_type, const ValueType &rhs_t
[](const auto &a, const auto &b){ return (a.name < b.name); });
}
+bool
+SparseJoinPlan::should_forward_lhs_index() const
+{
+ for (Source src: sources) {
+ if (src != Source::LHS) {
+ return false;
+ }
+ }
+ return (sources.size() > 0);
+}
+
+bool
+SparseJoinPlan::should_forward_rhs_index() const
+{
+ for (Source src: sources) {
+ if (src != Source::RHS) {
+ return false;
+ }
+ }
+ return (sources.size() > 0);
+}
+
SparseJoinPlan::~SparseJoinPlan() = default;
//-----------------------------------------------------------------------------
diff --git a/eval/src/vespa/eval/instruction/generic_join.h b/eval/src/vespa/eval/instruction/generic_join.h
index aaee64f7b25..1fcfcf416cc 100644
--- a/eval/src/vespa/eval/instruction/generic_join.h
+++ b/eval/src/vespa/eval/instruction/generic_join.h
@@ -55,6 +55,8 @@ struct SparseJoinPlan {
std::vector<Source> sources;
std::vector<size_t> lhs_overlap;
std::vector<size_t> rhs_overlap;
+ bool should_forward_lhs_index() const;
+ bool should_forward_rhs_index() const;
SparseJoinPlan(const ValueType &lhs_type, const ValueType &rhs_type);
~SparseJoinPlan();
};