diff options
author | Arne H Juul <arnej27959@users.noreply.github.com> | 2020-10-02 15:08:40 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-10-02 15:08:40 +0200 |
commit | c67283d48af378f25cf1bd3ed8e578d0e529813f (patch) | |
tree | 6fb45410c5b117ba4dbe1cdf511dd30e57fc5292 | |
parent | 6867d23640ee3a1b1d037dd3a1316ca323701721 (diff) | |
parent | d2a77e011a670b051f3dc3d7a70ed16d65ed6597 (diff) |
Merge pull request #14680 from vespa-engine/arnej/add-generic-mixed-merge
Arnej/add generic mixed merge.
-rw-r--r-- | eval/CMakeLists.txt | 1 | ||||
-rw-r--r-- | eval/src/tests/instruction/generic_merge/CMakeLists.txt | 9 | ||||
-rw-r--r-- | eval/src/tests/instruction/generic_merge/generic_merge_test.cpp | 82 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/test/tensor_model.hpp | 2 | ||||
-rw-r--r-- | eval/src/vespa/eval/instruction/CMakeLists.txt | 1 | ||||
-rw-r--r-- | eval/src/vespa/eval/instruction/generic_merge.cpp | 147 | ||||
-rw-r--r-- | eval/src/vespa/eval/instruction/generic_merge.h | 15 | ||||
-rw-r--r-- | eval/src/vespa/eval/tensor/serialization/sparse_binary_format.cpp | 27 |
8 files changed, 270 insertions, 14 deletions
diff --git a/eval/CMakeLists.txt b/eval/CMakeLists.txt index 03a19b8c2c4..1872ee1410a 100644 --- a/eval/CMakeLists.txt +++ b/eval/CMakeLists.txt @@ -33,6 +33,7 @@ vespa_define_module( src/tests/eval/value_type src/tests/gp/ponder_nov2017 src/tests/instruction/generic_join + src/tests/instruction/generic_merge src/tests/instruction/generic_rename src/tests/tensor/default_value_builder_factory src/tests/tensor/dense_add_dimension_optimizer diff --git a/eval/src/tests/instruction/generic_merge/CMakeLists.txt b/eval/src/tests/instruction/generic_merge/CMakeLists.txt new file mode 100644 index 00000000000..154b04cb32f --- /dev/null +++ b/eval/src/tests/instruction/generic_merge/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(eval_generic_merge_test_app TEST + SOURCES + generic_merge_test.cpp + DEPENDS + vespaeval + GTest::GTest +) +vespa_add_test(NAME eval_generic_merge_test_app COMMAND eval_generic_merge_test_app) diff --git a/eval/src/tests/instruction/generic_merge/generic_merge_test.cpp b/eval/src/tests/instruction/generic_merge/generic_merge_test.cpp new file mode 100644 index 00000000000..501d5410b87 --- /dev/null +++ b/eval/src/tests/instruction/generic_merge/generic_merge_test.cpp @@ -0,0 +1,82 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+ +#include <vespa/eval/eval/simple_value.h> +#include <vespa/eval/eval/value_codec.h> +#include <vespa/eval/instruction/generic_merge.h> +#include <vespa/eval/eval/interpreted_function.h> +#include <vespa/eval/eval/test/tensor_model.hpp> +#include <vespa/vespalib/util/stringfmt.h> +#include <vespa/vespalib/gtest/gtest.h> +#include <optional> + +using namespace vespalib; +using namespace vespalib::eval; +using namespace vespalib::eval::instruction; +using namespace vespalib::eval::test; + +using vespalib::make_string_short::fmt; + +std::vector<Layout> merge_layouts = { + {}, {}, + {x(5)}, {x(5)}, + {x(3),y(5)}, {x(3),y(5)}, + float_cells({x(3),y(5)}), {x(3),y(5)}, + {x(3),y(5)}, float_cells({x(3),y(5)}), + {x({"a","b","c"})}, {x({"a","b","c"})}, + {x({"a","b","c"})}, {x({"c","d","e"})}, + {x({"a","c","e"})}, {x({"b","c","d"})}, + {x({"b","c","d"})}, {x({"a","c","e"})}, + {x({"a","b","c"})}, {x({"c","d"})}, + {x({"a","b"}),y({"foo","bar","baz"})}, {x({"b","c"}),y({"any","foo","bar"})}, + {x(3),y({"foo", "bar"})}, {x(3),y({"baz", "bar"})}, + {x({"a","b","c"}),y(5)}, {x({"b","c","d"}),y(5)} +}; + + +TensorSpec reference_merge(const TensorSpec &a, const TensorSpec &b, join_fun_t fun) { + ValueType res_type = ValueType::merge(ValueType::from_spec(a.type()), + ValueType::from_spec(b.type())); + EXPECT_FALSE(res_type.is_error()); + TensorSpec result(res_type.to_spec()); + for (const auto &cell: a.cells()) { + auto other = b.cells().find(cell.first); + if (other == b.cells().end()) { + result.add(cell.first, cell.second); + } else { + result.add(cell.first, fun(cell.second, other->second)); + } + } + for (const auto &cell: b.cells()) { + auto other = a.cells().find(cell.first); + if (other == a.cells().end()) { + result.add(cell.first, cell.second); + } + } + return result; +} + +TensorSpec perform_generic_merge(const TensorSpec &a, const TensorSpec &b, join_fun_t fun) { + Stash stash; + const auto &factory = SimpleValueBuilderFactory::get(); + auto lhs = 
value_from_spec(a, factory); + auto rhs = value_from_spec(b, factory); + auto my_op = GenericMerge::make_instruction(lhs->type(), rhs->type(), fun, factory, stash); + InterpretedFunction::EvalSingle single(my_op); + return spec_from_value(single.eval(std::vector<Value::CREF>({*lhs, *rhs}))); +} + +TEST(GenericMergeTest, generic_merge_works_for_simple_values) { + ASSERT_TRUE((merge_layouts.size() % 2) == 0); + for (size_t i = 0; i < merge_layouts.size(); i += 2) { + TensorSpec lhs = spec(merge_layouts[i], N()); + TensorSpec rhs = spec(merge_layouts[i + 1], Div16(N())); + SCOPED_TRACE(fmt("\n===\nLHS: %s\nRHS: %s\n===\n", lhs.to_string().c_str(), rhs.to_string().c_str())); + for (auto fun: {operation::Add::f, operation::Mul::f, operation::Sub::f, operation::Max::f}) { + auto expect = reference_merge(lhs, rhs, fun); + auto actual = perform_generic_merge(lhs, rhs, fun); + EXPECT_EQ(actual, expect); + } + } +} + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/eval/src/vespa/eval/eval/test/tensor_model.hpp b/eval/src/vespa/eval/eval/test/tensor_model.hpp index 42f0dc7e996..4e4ef60aaee 100644 --- a/eval/src/vespa/eval/eval/test/tensor_model.hpp +++ b/eval/src/vespa/eval/eval/test/tensor_model.hpp @@ -38,7 +38,7 @@ struct Div10 : Sequence { double operator[](size_t i) const override { return (seq[i] / 10.0); } }; -// Sequence of another sequence divided by 10 +// Sequence of another sequence divided by 16 struct Div16 : Sequence { const Sequence &seq; Div16(const Sequence &seq_in) : seq(seq_in) {} diff --git a/eval/src/vespa/eval/instruction/CMakeLists.txt b/eval/src/vespa/eval/instruction/CMakeLists.txt index e5aae50750d..71d08f601dd 100644 --- a/eval/src/vespa/eval/instruction/CMakeLists.txt +++ b/eval/src/vespa/eval/instruction/CMakeLists.txt @@ -3,5 +3,6 @@ vespa_add_library(eval_instruction OBJECT SOURCES generic_join + generic_merge generic_rename ) diff --git a/eval/src/vespa/eval/instruction/generic_merge.cpp b/eval/src/vespa/eval/instruction/generic_merge.cpp new file 
mode 100644 index 00000000000..9d8ac2bb80a --- /dev/null +++ b/eval/src/vespa/eval/instruction/generic_merge.cpp @@ -0,0 +1,147 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "generic_merge.h" +#include <vespa/eval/eval/inline_operation.h> +#include <vespa/vespalib/util/stash.h> +#include <vespa/vespalib/util/typify.h> +#include <cassert> + +namespace vespalib::eval::instruction { + +using State = InterpretedFunction::State; +using Instruction = InterpretedFunction::Instruction; + +namespace { + +//----------------------------------------------------------------------------- + +template <typename T, typename IN> uint64_t wrap_param(const IN &value_in) { + const T &value = value_in; + static_assert(sizeof(uint64_t) == sizeof(&value)); + return (uint64_t)&value; +} + +template <typename T> const T &unwrap_param(uint64_t param) { + return *((const T *)param); +} + +struct MergeParam { + const ValueType res_type; + const join_fun_t function; + const size_t num_mapped_dimensions; + const size_t dense_subspace_size; + std::vector<size_t> all_view_dims; + const ValueBuilderFactory &factory; + MergeParam(const ValueType &lhs_type, const ValueType &rhs_type, + join_fun_t function_in, const ValueBuilderFactory &factory_in) + : res_type(ValueType::join(lhs_type, rhs_type)), + function(function_in), + num_mapped_dimensions(lhs_type.count_mapped_dimensions()), + dense_subspace_size(lhs_type.dense_subspace_size()), + all_view_dims(num_mapped_dimensions), + factory(factory_in) + { + assert(!res_type.is_error()); + assert(num_mapped_dimensions == rhs_type.count_mapped_dimensions()); + assert(num_mapped_dimensions == res_type.count_mapped_dimensions()); + assert(dense_subspace_size == rhs_type.dense_subspace_size()); + assert(dense_subspace_size == res_type.dense_subspace_size()); + for (size_t i = 0; i < num_mapped_dimensions; ++i) { + all_view_dims[i] = i; + } + } + ~MergeParam(); +}; 
+MergeParam::~MergeParam() = default; + +//----------------------------------------------------------------------------- + +template <typename LCT, typename RCT, typename OCT, typename Fun> +std::unique_ptr<Value> +generic_mixed_merge(const Value &a, const Value &b, + const MergeParam &params) +{ + Fun fun(params.function); + auto lhs_cells = a.cells().typify<LCT>(); + auto rhs_cells = b.cells().typify<RCT>(); + const size_t num_mapped = params.num_mapped_dimensions; + const size_t subspace_size = params.dense_subspace_size; + size_t guess_subspaces = std::max(a.index().size(), b.index().size()); + auto builder = params.factory.create_value_builder<OCT>(params.res_type, num_mapped, subspace_size, guess_subspaces); + std::vector<vespalib::stringref> address(num_mapped); + std::vector<const vespalib::stringref *> addr_cref; + std::vector<vespalib::stringref *> addr_ref; + for (auto & ref : address) { + addr_cref.push_back(&ref); + addr_ref.push_back(&ref); + } + size_t lhs_subspace; + size_t rhs_subspace; + auto inner = b.index().create_view(params.all_view_dims); + auto outer = a.index().create_view({}); + outer->lookup({}); + while (outer->next_result(addr_ref, lhs_subspace)) { + OCT *dst = builder->add_subspace(address).begin(); + inner->lookup(addr_cref); + if (inner->next_result({}, rhs_subspace)) { + const LCT *lhs_src = &lhs_cells[lhs_subspace * subspace_size]; + const RCT *rhs_src = &rhs_cells[rhs_subspace * subspace_size]; + for (size_t i = 0; i < subspace_size; ++i) { + *dst++ = fun(*lhs_src++, *rhs_src++); + } + } else { + const LCT *src = &lhs_cells[lhs_subspace * subspace_size]; + for (size_t i = 0; i < subspace_size; ++i) { + *dst++ = *src++; + } + } + } + inner = a.index().create_view(params.all_view_dims); + outer = b.index().create_view({}); + outer->lookup({}); + while (outer->next_result(addr_ref, rhs_subspace)) { + inner->lookup(addr_cref); + if (! 
inner->next_result({}, lhs_subspace)) { + OCT *dst = builder->add_subspace(address).begin(); + const RCT *src = &rhs_cells[rhs_subspace * subspace_size]; + for (size_t i = 0; i < subspace_size; ++i) { + *dst++ = *src++; + } + } + } + return builder->build(std::move(builder)); +} + +template <typename LCT, typename RCT, typename OCT, typename Fun> +void my_mixed_merge_op(State &state, uint64_t param_in) { + const auto &param = unwrap_param<MergeParam>(param_in); + const Value &lhs = state.peek(1); + const Value &rhs = state.peek(0); + auto up = generic_mixed_merge<LCT, RCT, OCT, Fun>(lhs, rhs, param); + auto &result = state.stash.create<std::unique_ptr<Value>>(std::move(up)); + const Value &result_ref = *(result.get()); + state.pop_pop_push(result_ref); +}; + +struct SelectGenericMergeOp { + template <typename LCT, typename RCT, typename OCT, typename Fun> static auto invoke() { + return my_mixed_merge_op<LCT,RCT,OCT,Fun>; + } +}; + +//----------------------------------------------------------------------------- + +} // namespace <unnamed> + +using MergeTypify = TypifyValue<TypifyCellType,operation::TypifyOp2>; + +Instruction +GenericMerge::make_instruction(const ValueType &lhs_type, const ValueType &rhs_type, join_fun_t function, + const ValueBuilderFactory &factory, Stash &stash) +{ + const auto &param = stash.create<MergeParam>(lhs_type, rhs_type, function, factory); + auto fun = typify_invoke<4,MergeTypify,SelectGenericMergeOp>(lhs_type.cell_type(), rhs_type.cell_type(), param.res_type.cell_type(), function); + return Instruction(fun, wrap_param<MergeParam>(param)); +} + +} // namespace diff --git a/eval/src/vespa/eval/instruction/generic_merge.h b/eval/src/vespa/eval/instruction/generic_merge.h new file mode 100644 index 00000000000..02e2d18715a --- /dev/null +++ b/eval/src/vespa/eval/instruction/generic_merge.h @@ -0,0 +1,15 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+ +#pragma once + +#include "generic_join.h" + +namespace vespalib::eval::instruction { + +struct GenericMerge { + static InterpretedFunction::Instruction + make_instruction(const ValueType &lhs_type, const ValueType &rhs_type, join_fun_t function, + const ValueBuilderFactory &factory, Stash &stash); +}; + +} // namespace diff --git a/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.cpp b/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.cpp index c246ccf16e1..a0b691872a0 100644 --- a/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.cpp +++ b/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.cpp @@ -131,6 +131,18 @@ SparseBinaryFormat::serialize(nbostream &stream, const Tensor &tensor) stream.write(cells.peek(), cells.size()); } +struct BuildSparseCells { + template<typename CT> + static auto invoke(ValueType type, nbostream &stream, + size_t dimensionsSize, + size_t cellsSize) + { + DirectSparseTensorBuilder<CT> builder(std::move(type)); + decodeCells<CT>(stream, dimensionsSize, cellsSize, builder); + return builder.build(); + } +}; + std::unique_ptr<Tensor> SparseBinaryFormat::deserialize(nbostream &stream, CellType cell_type) { @@ -143,19 +155,8 @@ SparseBinaryFormat::deserialize(nbostream &stream, CellType cell_type) } size_t cellsSize = stream.getInt1_4Bytes(); ValueType type = ValueType::tensor_type(std::move(dimensions), cell_type); - switch (cell_type) { - case CellType::DOUBLE: { - DirectSparseTensorBuilder<double> builder(type); - builder.reserve(cellsSize); - decodeCells<double>(stream, dimensionsSize, cellsSize, builder); - return builder.build(); } - case CellType::FLOAT: { - DirectSparseTensorBuilder<float> builder(type); - builder.reserve(cellsSize); - decodeCells<float>(stream, dimensionsSize, cellsSize, builder); - return builder.build(); } - } - abort(); + return typify_invoke<1,eval::TypifyCellType,BuildSparseCells>(cell_type, + std::move(type), stream, dimensionsSize, cellsSize); } } // namespace 
|