summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorArne H Juul <arnej27959@users.noreply.github.com>2020-10-02 15:08:40 +0200
committerGitHub <noreply@github.com>2020-10-02 15:08:40 +0200
commitc67283d48af378f25cf1bd3ed8e578d0e529813f (patch)
tree6fb45410c5b117ba4dbe1cdf511dd30e57fc5292
parent6867d23640ee3a1b1d037dd3a1316ca323701721 (diff)
parentd2a77e011a670b051f3dc3d7a70ed16d65ed6597 (diff)
Merge pull request #14680 from vespa-engine/arnej/add-generic-mixed-merge
Arnej/add generic mixed merge.
-rw-r--r--eval/CMakeLists.txt1
-rw-r--r--eval/src/tests/instruction/generic_merge/CMakeLists.txt9
-rw-r--r--eval/src/tests/instruction/generic_merge/generic_merge_test.cpp82
-rw-r--r--eval/src/vespa/eval/eval/test/tensor_model.hpp2
-rw-r--r--eval/src/vespa/eval/instruction/CMakeLists.txt1
-rw-r--r--eval/src/vespa/eval/instruction/generic_merge.cpp147
-rw-r--r--eval/src/vespa/eval/instruction/generic_merge.h15
-rw-r--r--eval/src/vespa/eval/tensor/serialization/sparse_binary_format.cpp27
8 files changed, 270 insertions, 14 deletions
diff --git a/eval/CMakeLists.txt b/eval/CMakeLists.txt
index 03a19b8c2c4..1872ee1410a 100644
--- a/eval/CMakeLists.txt
+++ b/eval/CMakeLists.txt
@@ -33,6 +33,7 @@ vespa_define_module(
src/tests/eval/value_type
src/tests/gp/ponder_nov2017
src/tests/instruction/generic_join
+ src/tests/instruction/generic_merge
src/tests/instruction/generic_rename
src/tests/tensor/default_value_builder_factory
src/tests/tensor/dense_add_dimension_optimizer
diff --git a/eval/src/tests/instruction/generic_merge/CMakeLists.txt b/eval/src/tests/instruction/generic_merge/CMakeLists.txt
new file mode 100644
index 00000000000..154b04cb32f
--- /dev/null
+++ b/eval/src/tests/instruction/generic_merge/CMakeLists.txt
@@ -0,0 +1,9 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+vespa_add_executable(eval_generic_merge_test_app TEST # test executable for the generic merge instruction
+ SOURCES
+ generic_merge_test.cpp
+ DEPENDS
+ vespaeval
+ GTest::GTest
+)
+vespa_add_test(NAME eval_generic_merge_test_app COMMAND eval_generic_merge_test_app) # the test runs the executable with no arguments
diff --git a/eval/src/tests/instruction/generic_merge/generic_merge_test.cpp b/eval/src/tests/instruction/generic_merge/generic_merge_test.cpp
new file mode 100644
index 00000000000..501d5410b87
--- /dev/null
+++ b/eval/src/tests/instruction/generic_merge/generic_merge_test.cpp
@@ -0,0 +1,82 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/eval/eval/simple_value.h>
+#include <vespa/eval/eval/value_codec.h>
+#include <vespa/eval/instruction/generic_merge.h>
+#include <vespa/eval/eval/interpreted_function.h>
+#include <vespa/eval/eval/test/tensor_model.hpp>
+#include <vespa/vespalib/util/stringfmt.h>
+#include <vespa/vespalib/gtest/gtest.h>
+#include <optional>
+
+using namespace vespalib;
+using namespace vespalib::eval;
+using namespace vespalib::eval::instruction;
+using namespace vespalib::eval::test;
+
+using vespalib::make_string_short::fmt;
+
+std::vector<Layout> merge_layouts = { // flat list of (lhs, rhs) layout pairs: entry 2k is merged with entry 2k+1
+ {}, {}, // both layouts empty
+ {x(5)}, {x(5)},
+ {x(3),y(5)}, {x(3),y(5)},
+ float_cells({x(3),y(5)}), {x(3),y(5)}, // float_cells on the lhs only
+ {x(3),y(5)}, float_cells({x(3),y(5)}), // float_cells on the rhs only
+ {x({"a","b","c"})}, {x({"a","b","c"})}, // identical label sets
+ {x({"a","b","c"})}, {x({"c","d","e"})}, // one shared label ("c")
+ {x({"a","c","e"})}, {x({"b","c","d"})}, // interleaved labels, one shared
+ {x({"b","c","d"})}, {x({"a","c","e"})}, // previous pair with the sides swapped
+ {x({"a","b","c"})}, {x({"c","d"})}, // label sets of different size
+ {x({"a","b"}),y({"foo","bar","baz"})}, {x({"b","c"}),y({"any","foo","bar"})}, // two labelled dimensions
+ {x(3),y({"foo", "bar"})}, {x(3),y({"baz", "bar"})}, // x(3) combined with a labelled y
+ {x({"a","b","c"}),y(5)}, {x({"b","c","d"}),y(5)} // labelled x combined with y(5)
+};
+
+
+TensorSpec reference_merge(const TensorSpec &a, const TensorSpec &b, join_fun_t fun) { // test oracle: merges a and b cell by cell on TensorSpec, without using GenericMerge
+ ValueType res_type = ValueType::merge(ValueType::from_spec(a.type()), // result type is derived from the two input type specs
+ ValueType::from_spec(b.type()));
+ EXPECT_FALSE(res_type.is_error()); // records a test failure if the merged type is an error type
+ TensorSpec result(res_type.to_spec());
+ for (const auto &cell: a.cells()) { // pass 1: every cell of a
+ auto other = b.cells().find(cell.first); // look for the same address in b
+ if (other == b.cells().end()) {
+ result.add(cell.first, cell.second); // only in a: value copied unchanged
+ } else {
+ result.add(cell.first, fun(cell.second, other->second)); // in both: fun(a value, b value)
+ }
+ }
+ for (const auto &cell: b.cells()) { // pass 2: cells of b that pass 1 did not cover
+ auto other = a.cells().find(cell.first);
+ if (other == a.cells().end()) {
+ result.add(cell.first, cell.second); // only in b: value copied unchanged
+ }
+ }
+ return result;
+}
+
+TensorSpec perform_generic_merge(const TensorSpec &a, const TensorSpec &b, join_fun_t fun) { // runs the GenericMerge instruction on a and b and returns the result as a TensorSpec
+ Stash stash; // passed to make_instruction below
+ const auto &factory = SimpleValueBuilderFactory::get();
+ auto lhs = value_from_spec(a, factory); // build both operand values from their specs using the same factory
+ auto rhs = value_from_spec(b, factory);
+ auto my_op = GenericMerge::make_instruction(lhs->type(), rhs->type(), fun, factory, stash);
+ InterpretedFunction::EvalSingle single(my_op); // evaluates just this one instruction
+ return spec_from_value(single.eval(std::vector<Value::CREF>({*lhs, *rhs}))); // operands are passed in the order lhs, rhs
+}
+
+TEST(GenericMergeTest, generic_merge_works_for_simple_values) { // compares GenericMerge against reference_merge for every layout pair and four operations
+ ASSERT_TRUE((merge_layouts.size() % 2) == 0); // layouts are consumed two at a time, so the count must be even
+ for (size_t i = 0; i < merge_layouts.size(); i += 2) {
+ TensorSpec lhs = spec(merge_layouts[i], N()); // lhs cell values come from the N() sequence
+ TensorSpec rhs = spec(merge_layouts[i + 1], Div16(N())); // rhs cell values come from Div16(N())
+ SCOPED_TRACE(fmt("\n===\nLHS: %s\nRHS: %s\n===\n", lhs.to_string().c_str(), rhs.to_string().c_str())); // failures inside this scope report both input specs
+ for (auto fun: {operation::Add::f, operation::Mul::f, operation::Sub::f, operation::Max::f}) { // includes the non-commutative Sub, so operand order matters
+ auto expect = reference_merge(lhs, rhs, fun);
+ auto actual = perform_generic_merge(lhs, rhs, fun);
+ EXPECT_EQ(actual, expect); // instruction result must equal the oracle result
+ }
+ }
+}
+
+GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/eval/src/vespa/eval/eval/test/tensor_model.hpp b/eval/src/vespa/eval/eval/test/tensor_model.hpp
index 42f0dc7e996..4e4ef60aaee 100644
--- a/eval/src/vespa/eval/eval/test/tensor_model.hpp
+++ b/eval/src/vespa/eval/eval/test/tensor_model.hpp
@@ -38,7 +38,7 @@ struct Div10 : Sequence {
double operator[](size_t i) const override { return (seq[i] / 10.0); }
};
-// Sequence of another sequence divided by 10
+// Sequence of another sequence divided by 16
struct Div16 : Sequence {
const Sequence &seq;
Div16(const Sequence &seq_in) : seq(seq_in) {}
diff --git a/eval/src/vespa/eval/instruction/CMakeLists.txt b/eval/src/vespa/eval/instruction/CMakeLists.txt
index e5aae50750d..71d08f601dd 100644
--- a/eval/src/vespa/eval/instruction/CMakeLists.txt
+++ b/eval/src/vespa/eval/instruction/CMakeLists.txt
@@ -3,5 +3,6 @@
vespa_add_library(eval_instruction OBJECT
SOURCES
generic_join
+ generic_merge
generic_rename
)
diff --git a/eval/src/vespa/eval/instruction/generic_merge.cpp b/eval/src/vespa/eval/instruction/generic_merge.cpp
new file mode 100644
index 00000000000..9d8ac2bb80a
--- /dev/null
+++ b/eval/src/vespa/eval/instruction/generic_merge.cpp
@@ -0,0 +1,147 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "generic_merge.h"
+#include <vespa/eval/eval/inline_operation.h>
+#include <vespa/vespalib/util/stash.h>
+#include <vespa/vespalib/util/typify.h>
+#include <cassert>
+
+namespace vespalib::eval::instruction {
+
+using State = InterpretedFunction::State;
+using Instruction = InterpretedFunction::Instruction;
+
+namespace {
+
+//-----------------------------------------------------------------------------
+
+template <typename T, typename IN> uint64_t wrap_param(const IN &value_in) { // encodes the address of value_in (viewed as const T) as a uint64_t
+ const T &value = value_in; // bind as const T& so the address taken below is that of a T
+ static_assert(sizeof(uint64_t) == sizeof(&value)); // compile-time check that a pointer is exactly as wide as uint64_t
+ return (uint64_t)&value; // only the address is stored, not a copy of the object
+}
+
+template <typename T> const T &unwrap_param(uint64_t param) { // counterpart of wrap_param: treats param as the address of a const T
+ return *((const T *)param); // dereferences that address; no validity check is performed here
+}
+
+struct MergeParam { // state for one merge instruction; make_instruction creates it in the Stash and passes its address as the instruction parameter
+ const ValueType res_type; // type of the merged result
+ const join_fun_t function; // combines two cell values when both inputs have the cell
+ const size_t num_mapped_dimensions; // taken from lhs_type; asserted equal for rhs_type and res_type
+ const size_t dense_subspace_size; // taken from lhs_type; asserted equal for rhs_type and res_type
+ std::vector<size_t> all_view_dims; // filled with 0..num_mapped_dimensions-1 by the constructor body
+ const ValueBuilderFactory &factory; // creates the builder for the result value
+ MergeParam(const ValueType &lhs_type, const ValueType &rhs_type,
+ join_fun_t function_in, const ValueBuilderFactory &factory_in)
+ : res_type(ValueType::join(lhs_type, rhs_type)), // NOTE(review): the unit test's reference uses ValueType::merge; confirm join is the intended type rule here
+ function(function_in),
+ num_mapped_dimensions(lhs_type.count_mapped_dimensions()),
+ dense_subspace_size(lhs_type.dense_subspace_size()),
+ all_view_dims(num_mapped_dimensions), // sized from num_mapped_dimensions, which is declared (and so initialized) earlier
+ factory(factory_in)
+ {
+ assert(!res_type.is_error()); // the input types must combine into a valid type
+ assert(num_mapped_dimensions == rhs_type.count_mapped_dimensions()); // both inputs and the result must agree on mapped dimension count
+ assert(num_mapped_dimensions == res_type.count_mapped_dimensions());
+ assert(dense_subspace_size == rhs_type.dense_subspace_size()); // both inputs and the result must agree on dense subspace size
+ assert(dense_subspace_size == res_type.dense_subspace_size());
+ for (size_t i = 0; i < num_mapped_dimensions; ++i) {
+ all_view_dims[i] = i; // identity index list
+ }
+ }
+ ~MergeParam(); // declared here, defaulted out of line just below the struct
+};
+MergeParam::~MergeParam() = default;
+
+//-----------------------------------------------------------------------------
+
+template <typename LCT, typename RCT, typename OCT, typename Fun> // LCT/RCT/OCT: cell types of lhs, rhs and output; Fun: wrapper constructed from params.function
+std::unique_ptr<Value>
+generic_mixed_merge(const Value &a, const Value &b, // builds a new value holding every subspace found in a or b
+ const MergeParam &params)
+{
+ Fun fun(params.function);
+ auto lhs_cells = a.cells().typify<LCT>();
+ auto rhs_cells = b.cells().typify<RCT>();
+ const size_t num_mapped = params.num_mapped_dimensions;
+ const size_t subspace_size = params.dense_subspace_size;
+ size_t guess_subspaces = std::max(a.index().size(), b.index().size()); // estimate handed to the builder; the real count can be larger when addresses differ
+ auto builder = params.factory.create_value_builder<OCT>(params.res_type, num_mapped, subspace_size, guess_subspaces);
+ std::vector<vespalib::stringref> address(num_mapped); // one label per mapped dimension, reused for every subspace
+ std::vector<const vespalib::stringref *> addr_cref; // read-only pointers into address, used as lookup keys
+ std::vector<vespalib::stringref *> addr_ref; // writable pointers into address, filled in by next_result
+ for (auto & ref : address) {
+ addr_cref.push_back(&ref);
+ addr_ref.push_back(&ref);
+ }
+ size_t lhs_subspace;
+ size_t rhs_subspace;
+ auto inner = b.index().create_view(params.all_view_dims); // view on b keyed by all mapped dimensions
+ auto outer = a.index().create_view({}); // view on a with no key dimensions
+ outer->lookup({});
+ while (outer->next_result(addr_ref, lhs_subspace)) { // pass 1: each subspace of a, its address written into address
+ OCT *dst = builder->add_subspace(address).begin(); // every subspace of a produces an output subspace
+ inner->lookup(addr_cref); // look up the same address in b
+ if (inner->next_result({}, rhs_subspace)) { // address present in both inputs
+ const LCT *lhs_src = &lhs_cells[lhs_subspace * subspace_size];
+ const RCT *rhs_src = &rhs_cells[rhs_subspace * subspace_size];
+ for (size_t i = 0; i < subspace_size; ++i) {
+ *dst++ = fun(*lhs_src++, *rhs_src++); // combine cell by cell, lhs as first argument
+ }
+ } else { // address only in a
+ const LCT *src = &lhs_cells[lhs_subspace * subspace_size];
+ for (size_t i = 0; i < subspace_size; ++i) {
+ *dst++ = *src++; // copy lhs cells (converted to OCT by the assignment)
+ }
+ }
+ }
+ inner = a.index().create_view(params.all_view_dims); // pass 2: roles swapped, a becomes the lookup side
+ outer = b.index().create_view({});
+ outer->lookup({});
+ while (outer->next_result(addr_ref, rhs_subspace)) { // each subspace of b
+ inner->lookup(addr_cref);
+ if (! inner->next_result({}, lhs_subspace)) { // address only in b; shared addresses were already emitted in pass 1
+ OCT *dst = builder->add_subspace(address).begin();
+ const RCT *src = &rhs_cells[rhs_subspace * subspace_size];
+ for (size_t i = 0; i < subspace_size; ++i) {
+ *dst++ = *src++; // copy rhs cells (converted to OCT by the assignment)
+ }
+ }
+ }
+ return builder->build(std::move(builder)); // the builder itself is moved into build()
+}
+
+template <typename LCT, typename RCT, typename OCT, typename Fun>
+void my_mixed_merge_op(State &state, uint64_t param_in) { // instruction body: merges the two values read from state and pushes the result
+ const auto &param = unwrap_param<MergeParam>(param_in); // param_in is the MergeParam address produced by wrap_param
+ const Value &lhs = state.peek(1); // lhs is read from position 1
+ const Value &rhs = state.peek(0); // rhs is read from position 0
+ auto up = generic_mixed_merge<LCT, RCT, OCT, Fun>(lhs, rhs, param);
+ auto &result = state.stash.create<std::unique_ptr<Value>>(std::move(up)); // the owning pointer is moved into an object created in state.stash
+ const Value &result_ref = *(result.get());
+ state.pop_pop_push(result_ref); // presumably replaces the two operands with the result; confirm against State
+};
+
+struct SelectGenericMergeOp { // selector handed to typify_invoke in GenericMerge::make_instruction
+ template <typename LCT, typename RCT, typename OCT, typename Fun> static auto invoke() {
+ return my_mixed_merge_op<LCT,RCT,OCT,Fun>; // the op function instantiated for these cell types and function wrapper
+ }
+};
+
+//-----------------------------------------------------------------------------
+
+} // namespace <unnamed>
+
+using MergeTypify = TypifyValue<TypifyCellType,operation::TypifyOp2>;
+
+Instruction
+GenericMerge::make_instruction(const ValueType &lhs_type, const ValueType &rhs_type, join_fun_t function, // builds the merge instruction for the given operand types and cell function
+ const ValueBuilderFactory &factory, Stash &stash)
+{
+ const auto &param = stash.create<MergeParam>(lhs_type, rhs_type, function, factory); // the MergeParam is created in the caller's stash
+ auto fun = typify_invoke<4,MergeTypify,SelectGenericMergeOp>(lhs_type.cell_type(), rhs_type.cell_type(), param.res_type.cell_type(), function); // four runtime values (three cell types and the function) select the op instantiation
+ return Instruction(fun, wrap_param<MergeParam>(param)); // the instruction carries only the address of param
+}
+
+} // namespace
diff --git a/eval/src/vespa/eval/instruction/generic_merge.h b/eval/src/vespa/eval/instruction/generic_merge.h
new file mode 100644
index 00000000000..02e2d18715a
--- /dev/null
+++ b/eval/src/vespa/eval/instruction/generic_merge.h
@@ -0,0 +1,15 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include "generic_join.h"
+
+namespace vespalib::eval::instruction {
+
+struct GenericMerge { // entry point for creating merge instructions; holds only a static factory function
+ static InterpretedFunction::Instruction
+ make_instruction(const ValueType &lhs_type, const ValueType &rhs_type, join_fun_t function, // function combines cells present in both operands
+ const ValueBuilderFactory &factory, Stash &stash); // factory builds the result value; stash receives the per-instruction parameters
+};
+
+} // namespace
diff --git a/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.cpp b/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.cpp
index c246ccf16e1..a0b691872a0 100644
--- a/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.cpp
+++ b/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.cpp
@@ -131,6 +131,18 @@ SparseBinaryFormat::serialize(nbostream &stream, const Tensor &tensor)
stream.write(cells.peek(), cells.size());
}
+struct BuildSparseCells { // typify_invoke target used by SparseBinaryFormat::deserialize; CT is chosen from the runtime cell type
+ template<typename CT>
+ static auto invoke(ValueType type, nbostream &stream, // type is taken by value and moved into the builder
+ size_t dimensionsSize,
+ size_t cellsSize)
+ {
+ DirectSparseTensorBuilder<CT> builder(std::move(type)); // NOTE(review): the switch this replaces called builder.reserve(cellsSize) before decoding; confirm dropping it is intentional
+ decodeCells<CT>(stream, dimensionsSize, cellsSize, builder); // reads the cells from stream into builder
+ return builder.build();
+ }
+};
+
std::unique_ptr<Tensor>
SparseBinaryFormat::deserialize(nbostream &stream, CellType cell_type)
{
@@ -143,19 +155,8 @@ SparseBinaryFormat::deserialize(nbostream &stream, CellType cell_type)
}
size_t cellsSize = stream.getInt1_4Bytes();
ValueType type = ValueType::tensor_type(std::move(dimensions), cell_type);
- switch (cell_type) {
- case CellType::DOUBLE: {
- DirectSparseTensorBuilder<double> builder(type);
- builder.reserve(cellsSize);
- decodeCells<double>(stream, dimensionsSize, cellsSize, builder);
- return builder.build(); }
- case CellType::FLOAT: {
- DirectSparseTensorBuilder<float> builder(type);
- builder.reserve(cellsSize);
- decodeCells<float>(stream, dimensionsSize, cellsSize, builder);
- return builder.build(); }
- }
- abort();
+ return typify_invoke<1,eval::TypifyCellType,BuildSparseCells>(cell_type,
+ std::move(type), stream, dimensionsSize, cellsSize);
}
} // namespace