diff options
3 files changed, 57 insertions, 0 deletions
diff --git a/eval/src/tests/instruction/generic_concat/generic_concat_test.cpp b/eval/src/tests/instruction/generic_concat/generic_concat_test.cpp index cfecdb97aa0..bc8ea84744f 100644 --- a/eval/src/tests/instruction/generic_concat/generic_concat_test.cpp +++ b/eval/src/tests/instruction/generic_concat/generic_concat_test.cpp @@ -52,6 +52,17 @@ std::vector<Layout> concat_layouts = { float_cells({x({"a","b","c"})}), float_cells({z({"foo","bar","baz"})}), {x({"a","b","c"})}, {x({"a","b","c"}),z({"foo","bar","baz"})}, {x({"a","b"}),z({"foo","bar","baz"})}, {x({"a","b","c"}),z({"foo","bar"})}, + {x({"a","b","c"}),y(3)}, {y(2)}, + {x({"a","b","c"}),y(3)}, {z(5)}, + {x({"a","b","c"}),y(3)}, {y(2),z(5)}, + {x({"a","b","c"}),y(3)}, {y(2)}, + {x({"a","b","c"}),y(3),z(5)}, {z(5)}, + {y(2)}, {x({"a","b","c"}),y(3)}, + {z(5)}, {x({"a","b","c"}),y(3)}, + {y(2),z(5)}, {x({"a","b","c"}),y(3)}, + {y(2)}, {x({"a","b","c"}),y(3)}, + {z(5)}, {x({"a","b","c"}),y(3),z(5)}, + {y(2),z(5)}, {x({"a","b","c"}),y(3),z(5)}, {y(2),x({"a","b","c"})}, {y(3),x({"b","c","d"})}, {y(2),x({"a","b"})}, {y(3),z({"c","d"})} }; diff --git a/eval/src/tests/tensor/instruction_benchmark/instruction_benchmark.cpp b/eval/src/tests/tensor/instruction_benchmark/instruction_benchmark.cpp index aa1da07bc91..89bddfda933 100644 --- a/eval/src/tests/tensor/instruction_benchmark/instruction_benchmark.cpp +++ b/eval/src/tests/tensor/instruction_benchmark/instruction_benchmark.cpp @@ -768,6 +768,12 @@ TEST(SparseConcat, small_vectors) { benchmark_concat("small sparse concat", lhs, rhs, "y"); } +TEST(MixedConcat, mixed_vs_dense) { + auto lhs = make_cube(D::idx("a", 16), D::idx("b", 16), D::map("c", 16, 1), 1.0); + auto rhs = make_matrix(D::idx("a", 16), D::idx("b", 16), 2.0); + benchmark_concat("mixed dense concat a", lhs, rhs, "a"); +} + TEST(MixedConcat, large_mixed_a) { auto lhs = make_cube(D::idx("a", 16), D::idx("b", 16), D::map("c", 16, 1), 1.0); auto rhs = make_cube(D::idx("a", 16), D::idx("b", 16), D::map("c", 
16, 2), 2.0); diff --git a/eval/src/vespa/eval/instruction/generic_concat.cpp b/eval/src/vespa/eval/instruction/generic_concat.cpp index 5d8ab7187c0..989e921b8fb 100644 --- a/eval/src/vespa/eval/instruction/generic_concat.cpp +++ b/eval/src/vespa/eval/instruction/generic_concat.cpp @@ -87,6 +87,40 @@ void my_generic_concat_op(State &state, uint64_t param_in) { state.pop_pop_push(result_ref); } +template <typename LCT, typename RCT, typename OCT, bool forward_lhs> +void my_mixed_dense_concat_op(State &state, uint64_t param_in) { + const auto &param = unwrap_param<ConcatParam>(param_in); + const DenseConcatPlan &dense_plan = param.dense_plan; + auto lhs_cells = state.peek(1).cells().typify<LCT>(); + auto rhs_cells = state.peek(0).cells().typify<RCT>(); + const auto &index = state.peek(forward_lhs ? 1 : 0).index(); + size_t num_subspaces = index.size(); + size_t num_out_cells = dense_plan.output_size * num_subspaces; + ArrayRef<OCT> out_cells = state.stash.create_uninitialized_array<OCT>(num_out_cells); + OCT *dst = out_cells.begin(); + const LCT *lhs = lhs_cells.begin(); + const RCT *rhs = rhs_cells.begin(); + auto copy_left = [&](size_t in_idx, size_t out_idx) { dst[out_idx] = lhs[in_idx]; }; + auto copy_right = [&](size_t in_idx, size_t out_idx) { dst[out_idx] = rhs[in_idx]; }; + for (size_t i = 0; i < num_subspaces; ++i) { + dense_plan.left.execute(0, 0, copy_left); + dense_plan.right.execute(0, dense_plan.right_offset, copy_right); + if (forward_lhs) { + lhs += dense_plan.left.input_size; + } else { + rhs += dense_plan.right.input_size; + } + dst += dense_plan.output_size; + } + if (forward_lhs) { + assert(lhs == lhs_cells.end()); + } else { + assert(rhs == rhs_cells.end()); + } + assert(dst == out_cells.end()); + state.pop_pop_push(state.stash.create<ValueView>(param.res_type, index, TypedCells(out_cells))); +} + template <typename LCT, typename RCT, typename OCT> void my_dense_simple_concat_op(State &state, uint64_t param_in) { const auto &param = 
unwrap_param<ConcatParam>(param_in); @@ -116,6 +150,12 @@ struct SelectGenericConcatOp { return my_dense_simple_concat_op<LCT, RCT, OCT>; } } + if (param.sparse_plan.should_forward_lhs_index()) { + return my_mixed_dense_concat_op<LCT, RCT, OCT, true>; + } + if (param.sparse_plan.should_forward_rhs_index()) { + return my_mixed_dense_concat_op<LCT, RCT, OCT, false>; + } return my_generic_concat_op<LCT, RCT, OCT>; } }; |