summary | refs | log | tree | commit | diff | stats
path: root/eval
diff options
context:
space:
mode:
Diffstat (limited to 'eval')
-rw-r--r-- eval/src/tests/instruction/generic_concat/generic_concat_test.cpp 11
-rw-r--r-- eval/src/tests/tensor/instruction_benchmark/instruction_benchmark.cpp 6
-rw-r--r-- eval/src/vespa/eval/instruction/generic_concat.cpp 40
3 files changed, 57 insertions, 0 deletions
diff --git a/eval/src/tests/instruction/generic_concat/generic_concat_test.cpp b/eval/src/tests/instruction/generic_concat/generic_concat_test.cpp
index cfecdb97aa0..bc8ea84744f 100644
--- a/eval/src/tests/instruction/generic_concat/generic_concat_test.cpp
+++ b/eval/src/tests/instruction/generic_concat/generic_concat_test.cpp
@@ -52,6 +52,17 @@ std::vector<Layout> concat_layouts = {
float_cells({x({"a","b","c"})}), float_cells({z({"foo","bar","baz"})}),
{x({"a","b","c"})}, {x({"a","b","c"}),z({"foo","bar","baz"})},
{x({"a","b"}),z({"foo","bar","baz"})}, {x({"a","b","c"}),z({"foo","bar"})},
+ {x({"a","b","c"}),y(3)}, {y(2)},
+ {x({"a","b","c"}),y(3)}, {z(5)},
+ {x({"a","b","c"}),y(3)}, {y(2),z(5)},
+ {x({"a","b","c"}),y(3)}, {y(2)},
+ {x({"a","b","c"}),y(3),z(5)}, {z(5)},
+ {y(2)}, {x({"a","b","c"}),y(3)},
+ {z(5)}, {x({"a","b","c"}),y(3)},
+ {y(2),z(5)}, {x({"a","b","c"}),y(3)},
+ {y(2)}, {x({"a","b","c"}),y(3)},
+ {z(5)}, {x({"a","b","c"}),y(3),z(5)},
+ {y(2),z(5)}, {x({"a","b","c"}),y(3),z(5)},
{y(2),x({"a","b","c"})}, {y(3),x({"b","c","d"})},
{y(2),x({"a","b"})}, {y(3),z({"c","d"})}
};
diff --git a/eval/src/tests/tensor/instruction_benchmark/instruction_benchmark.cpp b/eval/src/tests/tensor/instruction_benchmark/instruction_benchmark.cpp
index aa1da07bc91..89bddfda933 100644
--- a/eval/src/tests/tensor/instruction_benchmark/instruction_benchmark.cpp
+++ b/eval/src/tests/tensor/instruction_benchmark/instruction_benchmark.cpp
@@ -768,6 +768,12 @@ TEST(SparseConcat, small_vectors) {
benchmark_concat("small sparse concat", lhs, rhs, "y");
}
+// Benchmark concat along "a" where the left operand is built with two
+// D::idx dimensions (a, b) plus one D::map dimension (c), and the right
+// operand is built with the two D::idx dimensions (a, b) only.
+TEST(MixedConcat, mixed_vs_dense) {
+    auto mixed_lhs = make_cube(D::idx("a", 16), D::idx("b", 16), D::map("c", 16, 1), 1.0);
+    auto dense_rhs = make_matrix(D::idx("a", 16), D::idx("b", 16), 2.0);
+    benchmark_concat("mixed dense concat a", mixed_lhs, dense_rhs, "a");
+}
+
TEST(MixedConcat, large_mixed_a) {
auto lhs = make_cube(D::idx("a", 16), D::idx("b", 16), D::map("c", 16, 1), 1.0);
auto rhs = make_cube(D::idx("a", 16), D::idx("b", 16), D::map("c", 16, 2), 2.0);
diff --git a/eval/src/vespa/eval/instruction/generic_concat.cpp b/eval/src/vespa/eval/instruction/generic_concat.cpp
index 5d8ab7187c0..989e921b8fb 100644
--- a/eval/src/vespa/eval/instruction/generic_concat.cpp
+++ b/eval/src/vespa/eval/instruction/generic_concat.cpp
@@ -87,6 +87,40 @@ void my_generic_concat_op(State &state, uint64_t param_in) {
state.pop_pop_push(result_ref);
}
+/**
+ * Concat implementation used when the index of one input is forwarded
+ * unchanged to the result: the lhs index when forward_lhs is true,
+ * otherwise the rhs index.
+ *
+ * For each subspace of the forwarded index, one block of
+ * dense_plan.output_size output cells is filled by running the left and
+ * right parts of the dense plan. Only the cell pointer of the forwarded
+ * side advances between subspaces; the other side is re-read from its
+ * first cell every time (presumably because it has no mapped dimensions;
+ * confirm against should_forward_lhs_index()/should_forward_rhs_index()).
+ **/
+template <typename LCT, typename RCT, typename OCT, bool forward_lhs>
+void my_mixed_dense_concat_op(State &state, uint64_t param_in) {
+    const auto &param = unwrap_param<ConcatParam>(param_in);
+    const DenseConcatPlan &plan = param.dense_plan;
+    auto lhs_cells = state.peek(1).cells().typify<LCT>();
+    auto rhs_cells = state.peek(0).cells().typify<RCT>();
+    // the index of the forwarded side is reused as-is by the result
+    const auto &index = state.peek(forward_lhs ? 1 : 0).index();
+    const size_t subspace_count = index.size();
+    ArrayRef<OCT> out_cells =
+        state.stash.create_uninitialized_array<OCT>(plan.output_size * subspace_count);
+    OCT *dst = out_cells.begin();
+    const LCT *lhs = lhs_cells.begin();
+    const RCT *rhs = rhs_cells.begin();
+    // the callbacks capture the cursors by reference, so they always
+    // address the block currently being produced
+    auto copy_left = [&](size_t in_idx, size_t out_idx) {
+        dst[out_idx] = lhs[in_idx];
+    };
+    auto copy_right = [&](size_t in_idx, size_t out_idx) {
+        dst[out_idx] = rhs[in_idx];
+    };
+    for (size_t remaining = subspace_count; remaining > 0; --remaining) {
+        plan.left.execute(0, 0, copy_left);
+        plan.right.execute(0, plan.right_offset, copy_right);
+        // step only the input whose index was forwarded
+        if (forward_lhs) {
+            lhs += plan.left.input_size;
+        } else {
+            rhs += plan.right.input_size;
+        }
+        dst += plan.output_size;
+    }
+    // the forwarded input and the output must both be fully consumed
+    if (forward_lhs) {
+        assert(lhs == lhs_cells.end());
+    } else {
+        assert(rhs == rhs_cells.end());
+    }
+    assert(dst == out_cells.end());
+    auto &result = state.stash.create<ValueView>(param.res_type, index, TypedCells(out_cells));
+    state.pop_pop_push(result);
+}
+
template <typename LCT, typename RCT, typename OCT>
void my_dense_simple_concat_op(State &state, uint64_t param_in) {
const auto &param = unwrap_param<ConcatParam>(param_in);
@@ -116,6 +150,12 @@ struct SelectGenericConcatOp {
return my_dense_simple_concat_op<LCT, RCT, OCT>;
}
}
+ if (param.sparse_plan.should_forward_lhs_index()) {
+ return my_mixed_dense_concat_op<LCT, RCT, OCT, true>;
+ }
+ if (param.sparse_plan.should_forward_rhs_index()) {
+ return my_mixed_dense_concat_op<LCT, RCT, OCT, false>;
+ }
return my_generic_concat_op<LCT, RCT, OCT>;
}
};