diff options
author | Håvard Pettersen <havardpe@oath.com> | 2020-01-03 13:17:57 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2020-01-13 15:41:37 +0000 |
commit | b84803f6ec11b4881da067cf2ef8ae35335886b7 (patch) | |
tree | d1fb93dd0d462c0a547ceaa0cd881c733450f164 | |
parent | d220c3dd187e908afa015d5990e6dbbeb2e9876b (diff) |
tensor merge
36 files changed, 403 insertions, 17 deletions
diff --git a/eval/src/tests/eval/compiled_function/compiled_function_test.cpp b/eval/src/tests/eval/compiled_function/compiled_function_test.cpp index b01c849da1e..a19bff415e1 100644 --- a/eval/src/tests/eval/compiled_function/compiled_function_test.cpp +++ b/eval/src/tests/eval/compiled_function/compiled_function_test.cpp @@ -51,6 +51,7 @@ TEST("require that lazy parameter passing works") { std::vector<vespalib::string> unsupported = { "map(", "join(", + "merge(", "reduce(", "rename(", "tensor(", diff --git a/eval/src/tests/eval/node_types/node_types_test.cpp b/eval/src/tests/eval/node_types/node_types_test.cpp index dceaf279594..07bca5b53a6 100644 --- a/eval/src/tests/eval/node_types/node_types_test.cpp +++ b/eval/src/tests/eval/node_types/node_types_test.cpp @@ -195,6 +195,28 @@ TEST("require that join resolves correct type") { TEST_DO(verify_op2("join(%s,%s,f(x,y)(x+y))")); } +TEST("require that merge resolves to the appropriate type") { + const char *pattern = "merge(%s,%s,f(x,y)(x+y))"; + TEST_DO(verify(strfmt(pattern, "error", "error"), "error")); + TEST_DO(verify(strfmt(pattern, "double", "error"), "error")); + TEST_DO(verify(strfmt(pattern, "error", "double"), "error")); + TEST_DO(verify(strfmt(pattern, "tensor(x{})", "error"), "error")); + TEST_DO(verify(strfmt(pattern, "error", "tensor(x{})"), "error")); + TEST_DO(verify(strfmt(pattern, "double", "double"), "double")); + TEST_DO(verify(strfmt(pattern, "tensor(x{})", "double"), "error")); + TEST_DO(verify(strfmt(pattern, "double", "tensor(x{})"), "error")); + TEST_DO(verify(strfmt(pattern, "tensor(x{})", "tensor(x{})"), "tensor(x{})")); + TEST_DO(verify(strfmt(pattern, "tensor(x{})", "tensor(y{})"), "error")); + TEST_DO(verify(strfmt(pattern, "tensor(x[5])", "tensor(x[5])"), "tensor(x[5])")); + TEST_DO(verify(strfmt(pattern, "tensor(x[3])", "tensor(x[5])"), "error")); + TEST_DO(verify(strfmt(pattern, "tensor(x[5])", "tensor(x[3])"), "error")); + TEST_DO(verify(strfmt(pattern, "tensor(x{})", 
"tensor(x[5])"), "error")); + TEST_DO(verify(strfmt(pattern, "tensor<float>(x[5])", "tensor<float>(x[5])"), "tensor<float>(x[5])")); + TEST_DO(verify(strfmt(pattern, "tensor<float>(x[5])", "tensor(x[5])"), "tensor(x[5])")); + TEST_DO(verify(strfmt(pattern, "tensor(x[5])", "tensor<float>(x[5])"), "tensor(x[5])")); + TEST_DO(verify(strfmt(pattern, "tensor<float>(x[5])", "double"), "error")); +} + TEST("require that lambda tensor resolves correct type") { TEST_DO(verify("tensor(x[5])(1.0)", "tensor(x[5])")); TEST_DO(verify("tensor(x[5],y[10])(1.0)", "tensor(x[5],y[10])")); diff --git a/eval/src/tests/eval/simple_tensor/simple_tensor_test.cpp b/eval/src/tests/eval/simple_tensor/simple_tensor_test.cpp index 9fb56288dda..f3bea3820c8 100644 --- a/eval/src/tests/eval/simple_tensor/simple_tensor_test.cpp +++ b/eval/src/tests/eval/simple_tensor/simple_tensor_test.cpp @@ -106,6 +106,30 @@ TEST("require that simple tensors can be multiplied with each other") { EXPECT_EQUAL(to_spec(*expect), to_spec(unwrap(result2))); } +TEST("require that simple tensors can be merged") { + auto lhs = SimpleTensor::create( + TensorSpec("tensor(x{},y{})") + .add({{"x","1"},{"y","1"}}, 1) + .add({{"x","2"},{"y","1"}}, 3) + .add({{"x","1"},{"y","2"}}, 5)); + auto rhs = SimpleTensor::create( + TensorSpec("tensor(x{},y{})") + .add({{"x","1"},{"y","2"}}, 7) + .add({{"x","2"},{"y","2"}}, 11) + .add({{"x","1"},{"y","1"}}, 13)); + auto expect = SimpleTensor::create( + TensorSpec("tensor(x{},y{})") + .add({{"x","2"},{"y","1"}}, 3) + .add({{"x","1"},{"y","2"}}, 7) + .add({{"x","2"},{"y","2"}}, 11) + .add({{"x","1"},{"y","1"}}, 13)); + auto result = SimpleTensor::merge(*lhs, *rhs, [](double, double b){ return b; }); + EXPECT_EQUAL(to_spec(*expect), to_spec(*result)); + Stash stash; + const Value &result2 = SimpleTensorEngine::ref().merge(*lhs, *rhs, [](double, double b){ return b; }, stash); + EXPECT_EQUAL(to_spec(*expect), to_spec(unwrap(result2))); +} + TEST("require that simple tensors support dimension 
reduction") { auto tensor = SimpleTensor::create( TensorSpec("tensor(x[3],y[2])") diff --git a/eval/src/tests/eval/tensor_function/tensor_function_test.cpp b/eval/src/tests/eval/tensor_function/tensor_function_test.cpp index 1e59b35a01f..1eb9912abd2 100644 --- a/eval/src/tests/eval/tensor_function/tensor_function_test.cpp +++ b/eval/src/tests/eval/tensor_function/tensor_function_test.cpp @@ -154,6 +154,28 @@ struct EvalCtx { .add({{"x","2"},{"y","1"},{"z","2"}}, 39) .add({{"x","1"},{"y","2"},{"z","1"}}, 55)); } + Value::UP make_tensor_merge_lhs() { + return engine.from_spec( + TensorSpec("tensor(x{})") + .add({{"x","1"}}, 1) + .add({{"x","2"}}, 3) + .add({{"x","3"}}, 5)); + } + Value::UP make_tensor_merge_rhs() { + return engine.from_spec( + TensorSpec("tensor(x{})") + .add({{"x","2"}}, 7) + .add({{"x","3"}}, 9) + .add({{"x","4"}}, 11)); + } + Value::UP make_tensor_merge_output() { + return engine.from_spec( + TensorSpec("tensor(x{})") + .add({{"x","1"}}, 1) + .add({{"x","2"}}, 10) + .add({{"x","3"}}, 14) + .add({{"x","4"}}, 11)); + } }; void verify_equal(const Value &expect, const Value &value) { @@ -237,6 +259,20 @@ TEST("require that tensor join works") { TEST_DO(verify_equal(*expect, ctx.eval(prog))); } +TEST("require that tensor merge works") { + EvalCtx ctx(SimpleTensorEngine::ref()); + size_t a_id = ctx.add_tensor(ctx.make_tensor_merge_lhs()); + size_t b_id = ctx.add_tensor(ctx.make_tensor_merge_rhs()); + Value::UP expect = ctx.make_tensor_merge_output(); + const auto &fun = merge(inject(ValueType::from_spec("tensor(x{})"), a_id, ctx.stash), + inject(ValueType::from_spec("tensor(x{})"), b_id, ctx.stash), + operation::Add::f, ctx.stash); + EXPECT_TRUE(fun.result_is_mutable()); + EXPECT_EQUAL(expect->type(), fun.result_type()); + const auto &prog = ctx.compile(fun); + TEST_DO(verify_equal(*expect, ctx.eval(prog))); +} + TEST("require that tensor concat works") { EvalCtx ctx(SimpleTensorEngine::ref()); size_t a_id = 
ctx.add_tensor(ctx.make_tensor_matrix_first_half()); @@ -412,20 +448,25 @@ TEST("require that push_children works") { EXPECT_EQUAL(&refs[2].get().get(), &a); EXPECT_EQUAL(&refs[3].get().get(), &b); //------------------------------------------------------------------------- - concat(a, b, "x", stash).push_children(refs); + merge(a, b, operation::Add::f, stash).push_children(refs); ASSERT_EQUAL(refs.size(), 6u); EXPECT_EQUAL(&refs[4].get().get(), &a); EXPECT_EQUAL(&refs[5].get().get(), &b); //------------------------------------------------------------------------- + concat(a, b, "x", stash).push_children(refs); + ASSERT_EQUAL(refs.size(), 8u); + EXPECT_EQUAL(&refs[6].get().get(), &a); + EXPECT_EQUAL(&refs[7].get().get(), &b); + //------------------------------------------------------------------------- rename(c, {}, {}, stash).push_children(refs); - ASSERT_EQUAL(refs.size(), 7u); - EXPECT_EQUAL(&refs[6].get().get(), &c); + ASSERT_EQUAL(refs.size(), 9u); + EXPECT_EQUAL(&refs[8].get().get(), &c); //------------------------------------------------------------------------- if_node(a, b, c, stash).push_children(refs); - ASSERT_EQUAL(refs.size(), 10u); - EXPECT_EQUAL(&refs[7].get().get(), &a); - EXPECT_EQUAL(&refs[8].get().get(), &b); - EXPECT_EQUAL(&refs[9].get().get(), &c); + ASSERT_EQUAL(refs.size(), 12u); + EXPECT_EQUAL(&refs[9].get().get(), &a); + EXPECT_EQUAL(&refs[10].get().get(), &b); + EXPECT_EQUAL(&refs[11].get().get(), &c); //------------------------------------------------------------------------- } @@ -433,6 +474,7 @@ TEST("require that tensor function can be dumped for debugging") { Stash stash; auto my_value_1 = stash.create<DoubleValue>(1.0); auto my_value_2 = stash.create<DoubleValue>(2.0); + auto my_value_3 = stash.create<DoubleValue>(3.0); //------------------------------------------------------------------------- const auto &x5 = inject(ValueType::from_spec("tensor(x[5])"), 0, stash); const auto &mapped_x5 = map(x5, operation::Relu::f, stash); @@ 
-452,7 +494,9 @@ TEST("require that tensor function can be dumped for debugging") { const auto &concat_x5 = concat(x3, x2, "x", stash); //------------------------------------------------------------------------- const auto &const_2 = const_value(my_value_2, stash); - const auto &root = if_node(const_2, joined_x5, concat_x5, stash); + const auto &const_3 = const_value(my_value_3, stash); + const auto &merged_double = merge(const_2, const_3, operation::Less::f, stash); + const auto &root = if_node(merged_double, joined_x5, concat_x5, stash); EXPECT_EQUAL(root.result_type(), ValueType::from_spec("tensor(x[5])")); fprintf(stderr, "function dump -->[[%s]]<-- function dump\n", root.as_string().c_str()); } diff --git a/eval/src/tests/eval/value_type/value_type_test.cpp b/eval/src/tests/eval/value_type/value_type_test.cpp index 85ff7613775..055c7e582bf 100644 --- a/eval/src/tests/eval/value_type/value_type_test.cpp +++ b/eval/src/tests/eval/value_type/value_type_test.cpp @@ -408,6 +408,25 @@ TEST("require that tensor dimensions can be renamed") { EXPECT_EQUAL(type("error").rename({"a"}, {"b"}), type("error")); } +TEST("require that similar types can be merged") { + EXPECT_EQUAL(ValueType::merge(type("error"), type("error")), type("error")); + EXPECT_EQUAL(ValueType::merge(type("double"), type("double")), type("double")); + EXPECT_EQUAL(ValueType::merge(type("tensor(x[5])"), type("tensor(x[5])")), type("tensor(x[5])")); + EXPECT_EQUAL(ValueType::merge(type("tensor<float>(x[5])"), type("tensor(x[5])")), type("tensor(x[5])")); + EXPECT_EQUAL(ValueType::merge(type("tensor(x[5])"), type("tensor<float>(x[5])")), type("tensor(x[5])")); + EXPECT_EQUAL(ValueType::merge(type("tensor<float>(x[5])"), type("tensor<float>(x[5])")), type("tensor<float>(x[5])")); + EXPECT_EQUAL(ValueType::merge(type("tensor(x{})"), type("tensor(x{})")), type("tensor(x{})")); +} + +TEST("require that diverging types can not be merged") { + EXPECT_EQUAL(ValueType::merge(type("error"), type("double")), 
type("error")); + EXPECT_EQUAL(ValueType::merge(type("double"), type("error")), type("error")); + EXPECT_EQUAL(ValueType::merge(type("tensor(x[5])"), type("double")), type("error")); + EXPECT_EQUAL(ValueType::merge(type("double"), type("tensor(x[5])")), type("error")); + EXPECT_EQUAL(ValueType::merge(type("tensor(x[5])"), type("tensor(x[3])")), type("error")); + EXPECT_EQUAL(ValueType::merge(type("tensor(x{})"), type("tensor(y{})")), type("error")); +} + void verify_concat(const ValueType &a, const ValueType b, const vespalib::string &dim, const ValueType &res) { EXPECT_EQUAL(ValueType::concat(a, b, dim), res); EXPECT_EQUAL(ValueType::concat(b, a, dim), res); diff --git a/eval/src/vespa/eval/eval/function.cpp b/eval/src/vespa/eval/eval/function.cpp index 03e0ae5c0a2..53107fefb32 100644 --- a/eval/src/vespa/eval/eval/function.cpp +++ b/eval/src/vespa/eval/eval/function.cpp @@ -603,6 +603,15 @@ void parse_tensor_join(ParseContext &ctx) { ctx.push_expression(std::make_unique<nodes::TensorJoin>(std::move(lhs), std::move(rhs), std::move(lambda))); } +void parse_tensor_merge(ParseContext &ctx) { + Node_UP lhs = get_expression(ctx); + ctx.eat(','); + Node_UP rhs = get_expression(ctx); + ctx.eat(','); + auto lambda = parse_lambda(ctx, 2); + ctx.push_expression(std::make_unique<nodes::TensorMerge>(std::move(lhs), std::move(rhs), std::move(lambda))); +} + void parse_tensor_reduce(ParseContext &ctx) { Node_UP child = get_expression(ctx); ctx.eat(','); @@ -862,6 +871,8 @@ bool maybe_parse_call(ParseContext &ctx, const vespalib::string &name) { parse_tensor_map(ctx); } else if (name == "join") { parse_tensor_join(ctx); + } else if (name == "merge") { + parse_tensor_merge(ctx); } else if (name == "reduce") { parse_tensor_reduce(ctx); } else if (name == "rename") { diff --git a/eval/src/vespa/eval/eval/interpreted_function.cpp b/eval/src/vespa/eval/eval/interpreted_function.cpp index b93ad1eb10c..0a630a3e20a 100644 --- a/eval/src/vespa/eval/eval/interpreted_function.cpp +++ 
b/eval/src/vespa/eval/eval/interpreted_function.cpp @@ -28,6 +28,9 @@ const Function *get_lambda(const nodes::Node &node) { if (auto ptr = nodes::as<nodes::TensorJoin>(node)) { return &ptr->lambda(); } + if (auto ptr = nodes::as<nodes::TensorMerge>(node)) { + return &ptr->lambda(); + } return nullptr; } diff --git a/eval/src/vespa/eval/eval/key_gen.cpp b/eval/src/vespa/eval/eval/key_gen.cpp index 0f8382b5afc..46137b5878c 100644 --- a/eval/src/vespa/eval/eval/key_gen.cpp +++ b/eval/src/vespa/eval/eval/key_gen.cpp @@ -35,8 +35,9 @@ struct KeyGen : public NodeVisitor, public NodeTraverser { void visit(const Not &) override { add_byte( 6); } void visit(const If &node) override { add_byte( 7); add_double(node.p_true()); } void visit(const Error &) override { add_byte( 9); } - void visit(const TensorMap &) override { add_byte(11); } // lambda should be part of key - void visit(const TensorJoin &) override { add_byte(12); } // lambda should be part of key + void visit(const TensorMap &) override { add_byte(10); } // lambda should be part of key + void visit(const TensorJoin &) override { add_byte(11); } // lambda should be part of key + void visit(const TensorMerge &) override { add_byte(12); } // lambda should be part of key void visit(const TensorReduce &) override { add_byte(13); } // aggr/dimensions should be part of key void visit(const TensorRename &) override { add_byte(14); } // dimensions should be part of key void visit(const TensorConcat &) override { add_byte(15); } // dimension should be part of key diff --git a/eval/src/vespa/eval/eval/llvm/compiled_function.cpp b/eval/src/vespa/eval/eval/llvm/compiled_function.cpp index ca34c31e976..facfc502111 100644 --- a/eval/src/vespa/eval/eval/llvm/compiled_function.cpp +++ b/eval/src/vespa/eval/eval/llvm/compiled_function.cpp @@ -128,6 +128,7 @@ CompiledFunction::detect_issues(const Function &function) void close(const nodes::Node &node) override { if (nodes::check_type<nodes::TensorMap, nodes::TensorJoin, + 
nodes::TensorMerge, nodes::TensorReduce, nodes::TensorRename, nodes::TensorConcat, diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp index 72406f09a0d..1d5515d7f4a 100644 --- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp +++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp @@ -467,6 +467,9 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser { void visit(const TensorJoin &node) override { make_error(node.num_children()); } + void visit(const TensorMerge &node) override { + make_error(node.num_children()); + } void visit(const TensorReduce &node) override { make_error(node.num_children()); } diff --git a/eval/src/vespa/eval/eval/make_tensor_function.cpp b/eval/src/vespa/eval/eval/make_tensor_function.cpp index 803047d27c4..849270b89a7 100644 --- a/eval/src/vespa/eval/eval/make_tensor_function.cpp +++ b/eval/src/vespa/eval/eval/make_tensor_function.cpp @@ -77,6 +77,14 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser { stack.back() = tensor_function::join(a, b, function, stash); } + void make_merge(const Node &, join_fun_t function) { + assert(stack.size() >= 2); + const auto &b = stack.back(); + stack.pop_back(); + const auto &a = stack.back(); + stack.back() = tensor_function::merge(a, b, function, stash); + } + void make_concat(const Node &, const vespalib::string &dimension) { assert(stack.size() >= 2); const auto &b = stack.back(); @@ -197,6 +205,10 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser { make_join(node, token.get()->get().get_function<2>()); } } + void visit(const TensorMerge &node) override { + const auto &token = stash.create<CompileCache::Token::UP>(CompileCache::compile(node.lambda(), PassParams::SEPARATE)); + make_merge(node, token.get()->get().get_function<2>()); + } void visit(const TensorReduce &node) override { make_reduce(node, node.aggr(), node.dimensions()); } diff --git a/eval/src/vespa/eval/eval/node_types.cpp 
b/eval/src/vespa/eval/eval/node_types.cpp index 972afb7663f..bf87628e301 100644 --- a/eval/src/vespa/eval/eval/node_types.cpp +++ b/eval/src/vespa/eval/eval/node_types.cpp @@ -97,6 +97,10 @@ struct TypeResolver : public NodeVisitor, public NodeTraverser { } void visit(const TensorMap &node) override { resolve_op1(node); } void visit(const TensorJoin &node) override { resolve_op2(node); } + void visit(const TensorMerge &node) override { + bind(ValueType::merge(type(node.get_child(0)), + type(node.get_child(1))), node); + } void visit(const TensorReduce &node) override { const ValueType &child = type(node.get_child(0)); bind(child.reduce(node.dimensions()), node); diff --git a/eval/src/vespa/eval/eval/node_visitor.h b/eval/src/vespa/eval/eval/node_visitor.h index f32161a24c4..8f9722858b7 100644 --- a/eval/src/vespa/eval/eval/node_visitor.h +++ b/eval/src/vespa/eval/eval/node_visitor.h @@ -31,6 +31,7 @@ struct NodeVisitor { // tensor nodes virtual void visit(const nodes::TensorMap &) = 0; virtual void visit(const nodes::TensorJoin &) = 0; + virtual void visit(const nodes::TensorMerge &) = 0; virtual void visit(const nodes::TensorReduce &) = 0; virtual void visit(const nodes::TensorRename &) = 0; virtual void visit(const nodes::TensorConcat &) = 0; @@ -99,7 +100,8 @@ struct EmptyNodeVisitor : NodeVisitor { void visit(const nodes::If &) override {} void visit(const nodes::Error &) override {} void visit(const nodes::TensorMap &) override {} - void visit(const nodes::TensorJoin &) override {} + void visit(const nodes::TensorJoin &) override {} + void visit(const nodes::TensorMerge &) override {} void visit(const nodes::TensorReduce &) override {} void visit(const nodes::TensorRename &) override {} void visit(const nodes::TensorConcat &) override {} diff --git a/eval/src/vespa/eval/eval/simple_tensor.cpp b/eval/src/vespa/eval/eval/simple_tensor.cpp index b847d31335e..a72a24be211 100644 --- a/eval/src/vespa/eval/eval/simple_tensor.cpp +++ 
b/eval/src/vespa/eval/eval/simple_tensor.cpp @@ -3,6 +3,8 @@ #include "simple_tensor.h" #include "simple_tensor_engine.h" #include "operation.h" +#include <vespa/vespalib/util/overload.h> +#include <vespa/vespalib/util/visit_ranges.h> #include <vespa/vespalib/objects/nbostream.h> #include <algorithm> #include <cassert> @@ -678,6 +680,24 @@ SimpleTensor::join(const SimpleTensor &a, const SimpleTensor &b, join_fun_t func } std::unique_ptr<SimpleTensor> +SimpleTensor::merge(const SimpleTensor &a, const SimpleTensor &b, join_fun_t function) +{ + ValueType result_type = ValueType::merge(a.type(), b.type()); + if (result_type.is_error()) { + return std::make_unique<SimpleTensor>(); + } + Builder builder(result_type); + auto cmp = [](const Cell &x, const Cell &y) { return (x.address < y.address); }; + auto visitor = overload{ + [&builder](visit_ranges_either, const Cell &x) { builder.set(x.address, x.value); }, + [&builder,function](visit_ranges_both, const Cell &x, const Cell &y) { + builder.set(x.address, function(x.value, y.value)); + }}; + visit_ranges(visitor, a._cells.begin(), a._cells.end(), b._cells.begin(), b._cells.end(), cmp); + return builder.build(); +} + +std::unique_ptr<SimpleTensor> SimpleTensor::concat(const SimpleTensor &a, const SimpleTensor &b, const vespalib::string &dimension) { ValueType result_type = ValueType::concat(a.type(), b.type(), dimension); diff --git a/eval/src/vespa/eval/eval/simple_tensor.h b/eval/src/vespa/eval/eval/simple_tensor.h index 45d1853824d..cbf1ac99e05 100644 --- a/eval/src/vespa/eval/eval/simple_tensor.h +++ b/eval/src/vespa/eval/eval/simple_tensor.h @@ -89,6 +89,7 @@ public: std::unique_ptr<SimpleTensor> rename(const std::vector<vespalib::string> &from, const std::vector<vespalib::string> &to) const; static std::unique_ptr<SimpleTensor> create(const TensorSpec &spec); static std::unique_ptr<SimpleTensor> join(const SimpleTensor &a, const SimpleTensor &b, join_fun_t function); + static std::unique_ptr<SimpleTensor> 
merge(const SimpleTensor &a, const SimpleTensor &b, join_fun_t function); static std::unique_ptr<SimpleTensor> concat(const SimpleTensor &a, const SimpleTensor &b, const vespalib::string &dimension); static void encode(const SimpleTensor &tensor, nbostream &output); static std::unique_ptr<SimpleTensor> decode(nbostream &input); diff --git a/eval/src/vespa/eval/eval/simple_tensor_engine.cpp b/eval/src/vespa/eval/eval/simple_tensor_engine.cpp index 25fc98fc00a..6c2a0fcd53d 100644 --- a/eval/src/vespa/eval/eval/simple_tensor_engine.cpp +++ b/eval/src/vespa/eval/eval/simple_tensor_engine.cpp @@ -121,6 +121,12 @@ SimpleTensorEngine::join(const Value &a, const Value &b, join_fun_t function, St } const Value & +SimpleTensorEngine::merge(const Value &a, const Value &b, join_fun_t function, Stash &stash) const +{ + return to_value(SimpleTensor::merge(to_simple(a, stash), to_simple(b, stash), function), stash); +} + +const Value & SimpleTensorEngine::reduce(const Value &a, Aggr aggr, const std::vector<vespalib::string> &dimensions, Stash &stash) const { return to_value(to_simple(a, stash).reduce(Aggregator::create(aggr, stash), dimensions), stash); diff --git a/eval/src/vespa/eval/eval/simple_tensor_engine.h b/eval/src/vespa/eval/eval/simple_tensor_engine.h index 645ec4c4be7..4c71e91c8d3 100644 --- a/eval/src/vespa/eval/eval/simple_tensor_engine.h +++ b/eval/src/vespa/eval/eval/simple_tensor_engine.h @@ -27,6 +27,7 @@ public: const Value &map(const Value &a, map_fun_t function, Stash &stash) const override; const Value &join(const Value &a, const Value &b, join_fun_t function, Stash &stash) const override; + const Value &merge(const Value &a, const Value &b, join_fun_t function, Stash &stash) const override; const Value &reduce(const Value &a, Aggr aggr, const std::vector<vespalib::string> &dimensions, Stash &stash) const override; const Value &concat(const Value &a, const Value &b, const vespalib::string &dimension, Stash &stash) const override; const Value &rename(const 
Value &a, const std::vector<vespalib::string> &from, const std::vector<vespalib::string> &to, Stash &stash) const override; diff --git a/eval/src/vespa/eval/eval/tensor_engine.h b/eval/src/vespa/eval/eval/tensor_engine.h index 0ba25baed8c..f85f57cfa4f 100644 --- a/eval/src/vespa/eval/eval/tensor_engine.h +++ b/eval/src/vespa/eval/eval/tensor_engine.h @@ -49,6 +49,7 @@ struct TensorEngine virtual const Value &map(const Value &a, map_fun_t function, Stash &stash) const = 0; virtual const Value &join(const Value &a, const Value &b, join_fun_t function, Stash &stash) const = 0; + virtual const Value &merge(const Value &a, const Value &b, join_fun_t function, Stash &stash) const = 0; virtual const Value &reduce(const Value &a, Aggr aggr, const std::vector<vespalib::string> &dimensions, Stash &stash) const = 0; virtual const Value &concat(const Value &a, const Value &b, const vespalib::string &dimension, Stash &stash) const = 0; virtual const Value &rename(const Value &a, const std::vector<vespalib::string> &from, const std::vector<vespalib::string> &to, Stash &stash) const = 0; diff --git a/eval/src/vespa/eval/eval/tensor_function.cpp b/eval/src/vespa/eval/eval/tensor_function.cpp index 228a723e86a..45e8094570e 100644 --- a/eval/src/vespa/eval/eval/tensor_function.cpp +++ b/eval/src/vespa/eval/eval/tensor_function.cpp @@ -101,6 +101,10 @@ void op_tensor_join(State &state, uint64_t param) { state.pop_pop_push(state.engine.join(state.peek(1), state.peek(0), to_join_fun(param), state.stash)); } +void op_tensor_merge(State &state, uint64_t param) { + state.pop_pop_push(state.engine.merge(state.peek(1), state.peek(0), to_join_fun(param), state.stash)); +} + using ReduceParams = std::pair<Aggr,std::vector<vespalib::string>>; void op_tensor_reduce(State &state, uint64_t param) { const ReduceParams &params = unwrap_param<ReduceParams>(param); @@ -324,6 +328,21 @@ Join::visit_self(vespalib::ObjectVisitor &visitor) const 
//----------------------------------------------------------------------------- Instruction +Merge::compile_self(Stash &) const +{ + return Instruction(op_tensor_merge, to_param(_function)); +} + +void +Merge::visit_self(vespalib::ObjectVisitor &visitor) const +{ + Super::visit_self(visitor); + ::visit(visitor, "function", _function); +} + +//----------------------------------------------------------------------------- + +Instruction Concat::compile_self(Stash &) const { return Instruction(op_tensor_concat, wrap_param<vespalib::string>(_dimension)); @@ -471,6 +490,11 @@ const Node &join(const Node &lhs, const Node &rhs, join_fun_t function, Stash &s return stash.create<Join>(result_type, lhs, rhs, function); } +const Node &merge(const Node &lhs, const Node &rhs, join_fun_t function, Stash &stash) { + ValueType result_type = ValueType::merge(lhs.result_type(), rhs.result_type()); + return stash.create<Merge>(result_type, lhs, rhs, function); +} + const Node &concat(const Node &lhs, const Node &rhs, const vespalib::string &dimension, Stash &stash) { ValueType result_type = ValueType::concat(lhs.result_type(), rhs.result_type(), dimension); return stash.create<Concat>(result_type, lhs, rhs, dimension); diff --git a/eval/src/vespa/eval/eval/tensor_function.h b/eval/src/vespa/eval/eval/tensor_function.h index b019ab64e18..4b862e9ec6a 100644 --- a/eval/src/vespa/eval/eval/tensor_function.h +++ b/eval/src/vespa/eval/eval/tensor_function.h @@ -266,6 +266,25 @@ public: //----------------------------------------------------------------------------- +class Merge : public Op2 +{ + using Super = Op2; +private: + join_fun_t _function; +public: + Merge(const ValueType &result_type_in, + const TensorFunction &lhs_in, + const TensorFunction &rhs_in, + join_fun_t function_in) + : Op2(result_type_in, lhs_in, rhs_in), _function(function_in) {} + join_fun_t function() const { return _function; } + bool result_is_mutable() const override { return true; } + 
InterpretedFunction::Instruction compile_self(Stash &stash) const override; + void visit_self(vespalib::ObjectVisitor &visitor) const override; +}; + +//----------------------------------------------------------------------------- + class Concat : public Op2 { using Super = Op2; @@ -394,6 +413,7 @@ const Node &inject(const ValueType &type, size_t param_idx, Stash &stash); const Node &reduce(const Node &child, Aggr aggr, const std::vector<vespalib::string> &dimensions, Stash &stash); const Node &map(const Node &child, map_fun_t function, Stash &stash); const Node &join(const Node &lhs, const Node &rhs, join_fun_t function, Stash &stash); +const Node &merge(const Node &lhs, const Node &rhs, join_fun_t function, Stash &stash); const Node &concat(const Node &lhs, const Node &rhs, const vespalib::string &dimension, Stash &stash); const Node &create(const ValueType &type, const std::map<TensorSpec::Address, Node::CREF> &spec, Stash &stash); const Node &peek(const Node &param, const std::map<vespalib::string, std::variant<TensorSpec::Label, Node::CREF>> &spec, Stash &stash); diff --git a/eval/src/vespa/eval/eval/tensor_nodes.cpp b/eval/src/vespa/eval/eval/tensor_nodes.cpp index e392248a08a..82d108300dd 100644 --- a/eval/src/vespa/eval/eval/tensor_nodes.cpp +++ b/eval/src/vespa/eval/eval/tensor_nodes.cpp @@ -9,6 +9,7 @@ namespace nodes { void TensorMap ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } void TensorJoin ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } +void TensorMerge ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } void TensorReduce::accept(NodeVisitor &visitor) const { visitor.visit(*this); } void TensorRename::accept(NodeVisitor &visitor) const { visitor.visit(*this); } void TensorConcat::accept(NodeVisitor &visitor) const { visitor.visit(*this); } diff --git a/eval/src/vespa/eval/eval/tensor_nodes.h b/eval/src/vespa/eval/eval/tensor_nodes.h index 4213809f9e3..daba46c1fc5 100644 --- 
a/eval/src/vespa/eval/eval/tensor_nodes.h +++ b/eval/src/vespa/eval/eval/tensor_nodes.h @@ -76,6 +76,38 @@ public: } }; +class TensorMerge : public Node { +private: + Node_UP _lhs; + Node_UP _rhs; + std::shared_ptr<Function const> _lambda; +public: + TensorMerge(Node_UP lhs, Node_UP rhs, std::shared_ptr<Function const> lambda) + : _lhs(std::move(lhs)), _rhs(std::move(rhs)), _lambda(std::move(lambda)) {} + const Function &lambda() const { return *_lambda; } + vespalib::string dump(DumpContext &ctx) const override { + vespalib::string str; + str += "merge("; /* must be "merge", not "join": the dumped text has to name this node's own operation */ + str += _lhs->dump(ctx); + str += ","; + str += _rhs->dump(ctx); + str += ","; + str += _lambda->dump_as_lambda(); + str += ")"; + return str; + } + void accept(NodeVisitor &visitor) const override ; + size_t num_children() const override { return 2; } + const Node &get_child(size_t idx) const override { + assert(idx < 2); + return (idx == 0) ? *_lhs : *_rhs; + } + void detach_children(NodeHandler &handler) override { + handler.handle(std::move(_lhs)); + handler.handle(std::move(_rhs)); + } +}; + class TensorReduce : public Node { private: Node_UP _child; diff --git a/eval/src/vespa/eval/eval/test/eval_spec.cpp b/eval/src/vespa/eval/eval/test/eval_spec.cpp index dc1e3405a6b..709234a1a2c 100644 --- a/eval/src/vespa/eval/eval/test/eval_spec.cpp +++ b/eval/src/vespa/eval/eval/test/eval_spec.cpp @@ -165,6 +165,8 @@ EvalSpec::add_tensor_operation_cases() { add_rule({"a", -1.0, 1.0}, "map(a,f(x)(x+x*3))", [](double x){ return (x + (x * 3)); }); add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "join(a,b,f(x,y)(x+y))", [](double x, double y){ return (x + y); }); add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "join(a,b,f(x,y)(x+y*3))", [](double x, double y){ return (x + (y * 3)); }); + add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "merge(a,b,f(x,y)(x+y))", [](double x, double y){ return (x + y); }); + add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "merge(a,b,f(x,y)(x+y*3))", [](double x, double y){ return (x + (y * 3)); }); 
add_rule({"a", -1.0, 1.0}, "reduce(a,avg)", [](double a){ return a; }); add_rule({"a", -1.0, 1.0}, "reduce(a,count)", [](double){ return 1.0; }); add_rule({"a", -1.0, 1.0}, "reduce(a,prod)", [](double a){ return a; }); diff --git a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp index 9242b19310d..41c6dd21e24 100644 --- a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp +++ b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp @@ -829,6 +829,32 @@ struct TestContext { //------------------------------------------------------------------------- + void test_tensor_merge(const vespalib::string &type_base, const vespalib::string &a_str, + const vespalib::string &b_str, const vespalib::string &expect_str) + { + vespalib::string expr = "merge(a,b,f(x,y)(2*x+y))"; + for (bool a_float: {false, true}) { + for (bool b_float: {false, true}) { + bool both_float = a_float && b_float; + vespalib::string a_expr = make_string("tensor%s(%s):%s", a_float ? "<float>" : "", type_base.c_str(), a_str.c_str()); + vespalib::string b_expr = make_string("tensor%s(%s):%s", b_float ? "<float>" : "", type_base.c_str(), b_str.c_str()); + vespalib::string expect_expr = make_string("tensor%s(%s):%s", both_float ? 
"<float>" : "", type_base.c_str(), expect_str.c_str()); + TensorSpec a = spec(a_expr); + TensorSpec b = spec(b_expr); + TensorSpec expect = spec(expect_expr); + TEST_DO(verify_result(Expr_TT(expr).eval(engine, a, b), expect)); + } + } + } + + void test_tensor_merge() { + TEST_DO(test_tensor_merge("x[3]", "[1,2,3]", "[4,5,6]", "[6,9,12]")); + TEST_DO(test_tensor_merge("x{}", "{a:1,b:2,c:3}", "{b:4,c:5,d:6}", "{a:1,b:8,c:11,d:6}")); + TEST_DO(test_tensor_merge("x{},y[2]", "{a:[1,2],b:[3,4]}", "{b:[5,6],c:[6,7]}", "{a:[1,2],b:[11,14],c:[6,7]}")); + } + + //------------------------------------------------------------------------- + void verify_encode_decode(const TensorSpec &spec, const TensorEngine &encode_engine, const TensorEngine &decode_engine) @@ -913,6 +939,7 @@ struct TestContext { TEST_DO(test_tensor_lambda()); TEST_DO(test_tensor_create()); TEST_DO(test_tensor_peek()); + TEST_DO(test_tensor_merge()); TEST_DO(test_binary_format()); } }; diff --git a/eval/src/vespa/eval/eval/test/tensor_model.hpp b/eval/src/vespa/eval/eval/test/tensor_model.hpp index 6efb7470d55..2466701df62 100644 --- a/eval/src/vespa/eval/eval/test/tensor_model.hpp +++ b/eval/src/vespa/eval/eval/test/tensor_model.hpp @@ -5,6 +5,10 @@ #include <vespa/eval/eval/value_type.h> #include <vespa/eval/eval/operation.h> #include <vespa/eval/eval/tensor_engine.h> +#include <vespa/eval/eval/function.h> +#include <vespa/eval/eval/node_types.h> +#include <vespa/eval/eval/interpreted_function.h> +#include <vespa/eval/eval/simple_tensor_engine.h> namespace vespalib { namespace eval { @@ -292,6 +296,24 @@ TensorSpec spec(const vespalib::string &type, return spec; } +TensorSpec spec(const vespalib::string &value_expr) { + if (value_expr == "error") { + return TensorSpec("error"); + } + const auto &engine = SimpleTensorEngine::ref(); + auto fun = Function::parse(value_expr); + ASSERT_TRUE(!fun->has_error()); + ASSERT_EQUAL(fun->num_params(), 0u); + NodeTypes types(*fun, {}); + 
ASSERT_TRUE(!types.get_type(fun->root()).is_error()); + InterpretedFunction ifun(engine, *fun, types); + InterpretedFunction::Context ctx(ifun); + SimpleObjectParams params({}); + auto result = engine.to_spec(ifun.eval(ctx, params)); + ASSERT_TRUE(!result.cells().empty()); + return result; +} + } // namespace vespalib::eval::test } // namespace vespalib::eval } // namespace vespalib diff --git a/eval/src/vespa/eval/eval/value_type.cpp b/eval/src/vespa/eval/eval/value_type.cpp index 211a9c305a3..36c7c49c8b9 100644 --- a/eval/src/vespa/eval/eval/value_type.cpp +++ b/eval/src/vespa/eval/eval/value_type.cpp @@ -289,6 +289,20 @@ ValueType::join(const ValueType &lhs, const ValueType &rhs) return tensor_type(std::move(result.dimensions), unify(lhs._cell_type, rhs._cell_type)); } +ValueType +ValueType::merge(const ValueType &lhs, const ValueType &rhs) +{ + if ((lhs.type() != rhs.type()) || + (lhs.dimensions() != rhs.dimensions())) + { + return error_type(); + } + if (lhs.dimensions().empty()) { + return lhs; + } + return tensor_type(lhs.dimensions(), unify(lhs._cell_type, rhs._cell_type)); +} + CellType ValueType::unify_cell_types(const ValueType &a, const ValueType &b) { if (a.is_double()) { diff --git a/eval/src/vespa/eval/eval/value_type.h b/eval/src/vespa/eval/eval/value_type.h index b02053be3cb..1a4182dff2a 100644 --- a/eval/src/vespa/eval/eval/value_type.h +++ b/eval/src/vespa/eval/eval/value_type.h @@ -80,6 +80,7 @@ public: static ValueType from_spec(const vespalib::string &spec, std::vector<ValueType::Dimension> &unsorted); vespalib::string to_spec() const; static ValueType join(const ValueType &lhs, const ValueType &rhs); + static ValueType merge(const ValueType &lhs, const ValueType &rhs); static CellType unify_cell_types(const ValueType &a, const ValueType &b); static ValueType concat(const ValueType &lhs, const ValueType &rhs, const vespalib::string &dimension); static ValueType either(const ValueType &one, const ValueType &other); diff --git 
a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp index b310ad72fc1..b4449309812 100644 --- a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp +++ b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp @@ -84,6 +84,7 @@ const Value &to_default(const Value &value, Stash &stash) { } const Value &to_value(std::unique_ptr<Tensor> tensor, Stash &stash) { + assert(tensor); if (tensor->type().is_tensor()) { return *stash.create<Value::UP>(std::move(tensor)); } @@ -101,6 +102,10 @@ const Value &fallback_join(const Value &a, const Value &b, join_fun_t function, return to_default(simple_engine().join(to_simple(a, stash), to_simple(b, stash), function, stash), stash); } +const Value &fallback_merge(const Value &a, const Value &b, join_fun_t function, Stash &stash) { + return to_default(simple_engine().merge(to_simple(a, stash), to_simple(b, stash), function, stash), stash); +} + const Value &fallback_reduce(const Value &a, eval::Aggr aggr, const std::vector<vespalib::string> &dimensions, Stash &stash) { return to_default(simple_engine().reduce(to_simple(a, stash), aggr, dimensions, stash), stash); } @@ -329,6 +334,25 @@ DefaultTensorEngine::join(const Value &a, const Value &b, join_fun_t function, S } const Value & +DefaultTensorEngine::merge(const Value &a, const Value &b, join_fun_t function, Stash &stash) const +{ + if (auto tensor_a = a.as_tensor()) { + auto tensor_b = b.as_tensor(); + assert(tensor_b); + assert(&tensor_a->engine() == this); + assert(&tensor_b->engine() == this); + const tensor::Tensor &my_a = static_cast<const tensor::Tensor &>(*tensor_a); + const tensor::Tensor &my_b = static_cast<const tensor::Tensor &>(*tensor_b); + if (!tensor::Tensor::supported({my_a.type(), my_b.type()})) { + return fallback_merge(a, b, function, stash); + } + return to_value(my_a.merge(function, my_b), stash); + } else { + return stash.create<DoubleValue>(function(a.as_double(), b.as_double())); + } +} + +const Value 
& DefaultTensorEngine::reduce(const Value &a, Aggr aggr, const std::vector<vespalib::string> &dimensions, Stash &stash) const { if (auto tensor = a.as_tensor()) { diff --git a/eval/src/vespa/eval/tensor/default_tensor_engine.h b/eval/src/vespa/eval/tensor/default_tensor_engine.h index 29f9811e170..5c39706b326 100644 --- a/eval/src/vespa/eval/tensor/default_tensor_engine.h +++ b/eval/src/vespa/eval/tensor/default_tensor_engine.h @@ -28,6 +28,7 @@ public: const Value &map(const Value &a, map_fun_t function, Stash &stash) const override; const Value &join(const Value &a, const Value &b, join_fun_t function, Stash &stash) const override; + const Value &merge(const Value &a, const Value &b, join_fun_t function, Stash &stash) const override; const Value &reduce(const Value &a, Aggr aggr, const std::vector<vespalib::string> &dimensions, Stash &stash) const override; const Value &concat(const Value &a, const Value &b, const vespalib::string &dimension, Stash &stash) const override; const Value &rename(const Value &a, const std::vector<vespalib::string> &from, const std::vector<vespalib::string> &to, Stash &stash) const override; diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp index 3fed84323ca..4ed6758dfde 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp @@ -90,7 +90,7 @@ checkDimensions(const DenseTensorView &lhs, const DenseTensorView &rhs, template <typename LCT, typename RCT, typename Function> static Tensor::UP sameShapeJoin(const ConstArrayRef<LCT> &lhs, const ConstArrayRef<RCT> &rhs, - const eval::ValueType &lhs_type, + const std::vector<eval::ValueType::Dimension> &lhs_dims, Function &&func) { size_t sz = lhs.size(); @@ -106,7 +106,7 @@ sameShapeJoin(const ConstArrayRef<LCT> &lhs, const ConstArrayRef<RCT> &rhs, } assert(rhsCellItr == rhs.cend()); assert(newCells.size() == sz); - auto newType = 
eval::ValueType::tensor_type(lhs_type.dimensions(), eval::get_cell_type<OCT>()); + auto newType = eval::ValueType::tensor_type(lhs_dims, eval::get_cell_type<OCT>()); return std::make_unique<DenseTensor<OCT>>(std::move(newType), std::move(newCells)); } @@ -115,10 +115,10 @@ struct CallJoin template <typename LCT, typename RCT, typename Function> static Tensor::UP call(const ConstArrayRef<LCT> &lhs, const ConstArrayRef<RCT> &rhs, - const eval::ValueType &lhs_type, + const std::vector<eval::ValueType::Dimension> &lhs_dims, Function &&func) { - return sameShapeJoin(lhs, rhs, lhs_type, std::move(func)); + return sameShapeJoin(lhs, rhs, lhs_dims, std::move(func)); } }; @@ -129,7 +129,7 @@ joinDenseTensors(const DenseTensorView &lhs, const DenseTensorView &rhs, { TypedCells lhsCells = lhs.cellsRef(); TypedCells rhsCells = rhs.cellsRef(); - return dispatch_2<CallJoin>(lhsCells, rhsCells, lhs.fast_type(), std::move(func)); + return dispatch_2<CallJoin>(lhsCells, rhsCells, lhs.fast_type().dimensions(), std::move(func)); } template <typename Function> @@ -289,7 +289,7 @@ DenseTensorView::accept(TensorVisitor &visitor) const Tensor::UP DenseTensorView::join(join_fun_t function, const Tensor &arg) const { - if (fast_type() == arg.type()) { + if (fast_type().dimensions() == arg.type().dimensions()) { if (function == eval::operation::Mul::f) { return joinDenseTensors(*this, arg, "mul", [](double a, double b) { return (a * b); }); @@ -310,6 +310,13 @@ DenseTensorView::join(join_fun_t function, const Tensor &arg) const } Tensor::UP +DenseTensorView::merge(join_fun_t function, const Tensor &arg) const +{ + assert(fast_type().dimensions() == arg.type().dimensions()); + return join(function, arg); +} + +Tensor::UP DenseTensorView::reduce_all(join_fun_t op, const std::vector<vespalib::string> &dims) const { if (op == eval::operation::Mul::f) { diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h index 
778f2aa2871..33183d267c1 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h @@ -34,6 +34,7 @@ public: double as_double() const override; Tensor::UP apply(const CellFunction &func) const override; Tensor::UP join(join_fun_t function, const Tensor &arg) const override; + Tensor::UP merge(join_fun_t function, const Tensor &arg) const override; Tensor::UP reduce(join_fun_t op, const std::vector<vespalib::string> &dimensions) const override; std::unique_ptr<Tensor> modify(join_fun_t op, const CellValues &cellValues) const override; std::unique_ptr<Tensor> add(const Tensor &arg) const override; diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp index bcfbc851e6d..1fc93e8234f 100644 --- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp @@ -8,6 +8,7 @@ #include "sparse_tensor_modify.h" #include "sparse_tensor_reduce.hpp" #include "sparse_tensor_remove.h" +#include "direct_sparse_tensor_builder.h" #include <vespa/eval/eval/operation.h> #include <vespa/eval/tensor/cell_values.h> #include <vespa/eval/tensor/tensor_address_builder.h> @@ -175,6 +176,30 @@ SparseTensor::join(join_fun_t function, const Tensor &arg) const } Tensor::UP +SparseTensor::merge(join_fun_t function, const Tensor &arg) const +{ + const SparseTensor *rhs = dynamic_cast<const SparseTensor *>(&arg); + assert(rhs && (fast_type().dimensions() == rhs->fast_type().dimensions())); + DirectSparseTensorBuilder builder(eval::ValueType::merge(fast_type(), rhs->fast_type())); + builder.reserve(cells().size() + rhs->cells().size()); + for (const auto &cell: cells()) { + auto pos = rhs->cells().find(cell.first); + if (pos == rhs->cells().end()) { + builder.insertCell(cell.first, cell.second); + } else { + builder.insertCell(cell.first, function(cell.second, pos->second)); + } + } + for (const auto &cell: rhs->cells()) { 
+ auto pos = cells().find(cell.first); + if (pos == cells().end()) { + builder.insertCell(cell.first, cell.second); + } + } + return builder.build(); +} + +Tensor::UP SparseTensor::reduce(join_fun_t op, const std::vector<vespalib::string> &dimensions) const { diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h index c182c09c6b0..880cd32c605 100644 --- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h @@ -44,6 +44,7 @@ public: double as_double() const override; Tensor::UP apply(const CellFunction &func) const override; Tensor::UP join(join_fun_t function, const Tensor &arg) const override; + Tensor::UP merge(join_fun_t function, const Tensor &arg) const override; Tensor::UP reduce(join_fun_t op, const std::vector<vespalib::string> &dimensions) const override; std::unique_ptr<Tensor> modify(join_fun_t op, const CellValues &cellValues) const override; std::unique_ptr<Tensor> add(const Tensor &arg) const override; diff --git a/eval/src/vespa/eval/tensor/tensor.h b/eval/src/vespa/eval/tensor/tensor.h index edf5fa710e3..d822c99a6d8 100644 --- a/eval/src/vespa/eval/tensor/tensor.h +++ b/eval/src/vespa/eval/tensor/tensor.h @@ -33,6 +33,7 @@ public: virtual ~Tensor() {} virtual Tensor::UP apply(const CellFunction &func) const = 0; virtual Tensor::UP join(join_fun_t function, const Tensor &arg) const = 0; + virtual Tensor::UP merge(join_fun_t function, const Tensor &arg) const = 0; virtual Tensor::UP reduce(join_fun_t op, const std::vector<vespalib::string> &dimensions) const = 0; /* diff --git a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp index a982a4b0fe1..7c09bc4e4ab 100644 --- a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp +++ b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp @@ -76,6 +76,12 @@ WrappedSimpleTensor::join(join_fun_t, const Tensor &) const } Tensor::UP 
+WrappedSimpleTensor::merge(join_fun_t, const Tensor &) const +{ + LOG_ABORT("should not be reached"); +} + +Tensor::UP WrappedSimpleTensor::reduce(join_fun_t, const std::vector<vespalib::string> &) const { LOG_ABORT("should not be reached"); diff --git a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h index e7ffe7a755f..12ee1237d67 100644 --- a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h +++ b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h @@ -37,6 +37,7 @@ public: // functions below should not be used for this implementation Tensor::UP apply(const CellFunction &) const override; Tensor::UP join(join_fun_t, const Tensor &) const override; + Tensor::UP merge(join_fun_t, const Tensor &) const override; Tensor::UP reduce(join_fun_t, const std::vector<vespalib::string> &) const override; std::unique_ptr<Tensor> modify(join_fun_t, const CellValues &) const override; std::unique_ptr<Tensor> add(const Tensor &arg) const override; |