summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorArne H Juul <arnej27959@users.noreply.github.com>2020-01-14 11:28:54 +0100
committerGitHub <noreply@github.com>2020-01-14 11:28:54 +0100
commitf9da0c56a59ab6c29c2f4dfbf350edd4347558ea (patch)
tree8edc0edf15703bcef1372509a737e8c93038a5aa
parent5a9dcffe32618e0b704996c5ddb33230cca6cbae (diff)
parentb84803f6ec11b4881da067cf2ef8ae35335886b7 (diff)
Merge pull request #11772 from vespa-engine/havardpe/tensor-merge
tensor merge
-rw-r--r--eval/src/tests/eval/compiled_function/compiled_function_test.cpp1
-rw-r--r--eval/src/tests/eval/node_types/node_types_test.cpp22
-rw-r--r--eval/src/tests/eval/simple_tensor/simple_tensor_test.cpp24
-rw-r--r--eval/src/tests/eval/tensor_function/tensor_function_test.cpp60
-rw-r--r--eval/src/tests/eval/value_type/value_type_test.cpp19
-rw-r--r--eval/src/vespa/eval/eval/function.cpp11
-rw-r--r--eval/src/vespa/eval/eval/interpreted_function.cpp3
-rw-r--r--eval/src/vespa/eval/eval/key_gen.cpp5
-rw-r--r--eval/src/vespa/eval/eval/llvm/compiled_function.cpp1
-rw-r--r--eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp3
-rw-r--r--eval/src/vespa/eval/eval/make_tensor_function.cpp12
-rw-r--r--eval/src/vespa/eval/eval/node_types.cpp4
-rw-r--r--eval/src/vespa/eval/eval/node_visitor.h4
-rw-r--r--eval/src/vespa/eval/eval/simple_tensor.cpp20
-rw-r--r--eval/src/vespa/eval/eval/simple_tensor.h1
-rw-r--r--eval/src/vespa/eval/eval/simple_tensor_engine.cpp6
-rw-r--r--eval/src/vespa/eval/eval/simple_tensor_engine.h1
-rw-r--r--eval/src/vespa/eval/eval/tensor_engine.h1
-rw-r--r--eval/src/vespa/eval/eval/tensor_function.cpp24
-rw-r--r--eval/src/vespa/eval/eval/tensor_function.h20
-rw-r--r--eval/src/vespa/eval/eval/tensor_nodes.cpp1
-rw-r--r--eval/src/vespa/eval/eval/tensor_nodes.h32
-rw-r--r--eval/src/vespa/eval/eval/test/eval_spec.cpp2
-rw-r--r--eval/src/vespa/eval/eval/test/tensor_conformance.cpp27
-rw-r--r--eval/src/vespa/eval/eval/test/tensor_model.hpp22
-rw-r--r--eval/src/vespa/eval/eval/value_type.cpp14
-rw-r--r--eval/src/vespa/eval/eval/value_type.h1
-rw-r--r--eval/src/vespa/eval/tensor/default_tensor_engine.cpp24
-rw-r--r--eval/src/vespa/eval/tensor/default_tensor_engine.h1
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp19
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_view.h1
-rw-r--r--eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp25
-rw-r--r--eval/src/vespa/eval/tensor/sparse/sparse_tensor.h1
-rw-r--r--eval/src/vespa/eval/tensor/tensor.h1
-rw-r--r--eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp6
-rw-r--r--eval/src/vespa/eval/tensor/wrapped_simple_tensor.h1
36 files changed, 403 insertions, 17 deletions
diff --git a/eval/src/tests/eval/compiled_function/compiled_function_test.cpp b/eval/src/tests/eval/compiled_function/compiled_function_test.cpp
index b01c849da1e..a19bff415e1 100644
--- a/eval/src/tests/eval/compiled_function/compiled_function_test.cpp
+++ b/eval/src/tests/eval/compiled_function/compiled_function_test.cpp
@@ -51,6 +51,7 @@ TEST("require that lazy parameter passing works") {
std::vector<vespalib::string> unsupported = {
"map(",
"join(",
+ "merge(",
"reduce(",
"rename(",
"tensor(",
diff --git a/eval/src/tests/eval/node_types/node_types_test.cpp b/eval/src/tests/eval/node_types/node_types_test.cpp
index dceaf279594..07bca5b53a6 100644
--- a/eval/src/tests/eval/node_types/node_types_test.cpp
+++ b/eval/src/tests/eval/node_types/node_types_test.cpp
@@ -195,6 +195,28 @@ TEST("require that join resolves correct type") {
TEST_DO(verify_op2("join(%s,%s,f(x,y)(x+y))"));
}
+TEST("require that merge resolves to the appropriate type") {
+ const char *pattern = "merge(%s,%s,f(x,y)(x+y))";
+ TEST_DO(verify(strfmt(pattern, "error", "error"), "error"));
+ TEST_DO(verify(strfmt(pattern, "double", "error"), "error"));
+ TEST_DO(verify(strfmt(pattern, "error", "double"), "error"));
+ TEST_DO(verify(strfmt(pattern, "tensor(x{})", "error"), "error"));
+ TEST_DO(verify(strfmt(pattern, "error", "tensor(x{})"), "error"));
+ TEST_DO(verify(strfmt(pattern, "double", "double"), "double"));
+ TEST_DO(verify(strfmt(pattern, "tensor(x{})", "double"), "error"));
+ TEST_DO(verify(strfmt(pattern, "double", "tensor(x{})"), "error"));
+ TEST_DO(verify(strfmt(pattern, "tensor(x{})", "tensor(x{})"), "tensor(x{})"));
+ TEST_DO(verify(strfmt(pattern, "tensor(x{})", "tensor(y{})"), "error"));
+ TEST_DO(verify(strfmt(pattern, "tensor(x[5])", "tensor(x[5])"), "tensor(x[5])"));
+ TEST_DO(verify(strfmt(pattern, "tensor(x[3])", "tensor(x[5])"), "error"));
+ TEST_DO(verify(strfmt(pattern, "tensor(x[5])", "tensor(x[3])"), "error"));
+ TEST_DO(verify(strfmt(pattern, "tensor(x{})", "tensor(x[5])"), "error"));
+ TEST_DO(verify(strfmt(pattern, "tensor<float>(x[5])", "tensor<float>(x[5])"), "tensor<float>(x[5])"));
+ TEST_DO(verify(strfmt(pattern, "tensor<float>(x[5])", "tensor(x[5])"), "tensor(x[5])"));
+ TEST_DO(verify(strfmt(pattern, "tensor(x[5])", "tensor<float>(x[5])"), "tensor(x[5])"));
+ TEST_DO(verify(strfmt(pattern, "tensor<float>(x[5])", "double"), "error"));
+}
+
TEST("require that lambda tensor resolves correct type") {
TEST_DO(verify("tensor(x[5])(1.0)", "tensor(x[5])"));
TEST_DO(verify("tensor(x[5],y[10])(1.0)", "tensor(x[5],y[10])"));
diff --git a/eval/src/tests/eval/simple_tensor/simple_tensor_test.cpp b/eval/src/tests/eval/simple_tensor/simple_tensor_test.cpp
index 9fb56288dda..f3bea3820c8 100644
--- a/eval/src/tests/eval/simple_tensor/simple_tensor_test.cpp
+++ b/eval/src/tests/eval/simple_tensor/simple_tensor_test.cpp
@@ -106,6 +106,30 @@ TEST("require that simple tensors can be multiplied with each other") {
EXPECT_EQUAL(to_spec(*expect), to_spec(unwrap(result2)));
}
+TEST("require that simple tensors can be merged") {
+ auto lhs = SimpleTensor::create(
+ TensorSpec("tensor(x{},y{})")
+ .add({{"x","1"},{"y","1"}}, 1)
+ .add({{"x","2"},{"y","1"}}, 3)
+ .add({{"x","1"},{"y","2"}}, 5));
+ auto rhs = SimpleTensor::create(
+ TensorSpec("tensor(x{},y{})")
+ .add({{"x","1"},{"y","2"}}, 7)
+ .add({{"x","2"},{"y","2"}}, 11)
+ .add({{"x","1"},{"y","1"}}, 13));
+ auto expect = SimpleTensor::create(
+ TensorSpec("tensor(x{},y{})")
+ .add({{"x","2"},{"y","1"}}, 3)
+ .add({{"x","1"},{"y","2"}}, 7)
+ .add({{"x","2"},{"y","2"}}, 11)
+ .add({{"x","1"},{"y","1"}}, 13));
+ auto result = SimpleTensor::merge(*lhs, *rhs, [](double, double b){ return b; });
+ EXPECT_EQUAL(to_spec(*expect), to_spec(*result));
+ Stash stash;
+ const Value &result2 = SimpleTensorEngine::ref().merge(*lhs, *rhs, [](double, double b){ return b; }, stash);
+ EXPECT_EQUAL(to_spec(*expect), to_spec(unwrap(result2)));
+}
+
TEST("require that simple tensors support dimension reduction") {
auto tensor = SimpleTensor::create(
TensorSpec("tensor(x[3],y[2])")
diff --git a/eval/src/tests/eval/tensor_function/tensor_function_test.cpp b/eval/src/tests/eval/tensor_function/tensor_function_test.cpp
index 1e59b35a01f..1eb9912abd2 100644
--- a/eval/src/tests/eval/tensor_function/tensor_function_test.cpp
+++ b/eval/src/tests/eval/tensor_function/tensor_function_test.cpp
@@ -154,6 +154,28 @@ struct EvalCtx {
.add({{"x","2"},{"y","1"},{"z","2"}}, 39)
.add({{"x","1"},{"y","2"},{"z","1"}}, 55));
}
+ Value::UP make_tensor_merge_lhs() {
+ return engine.from_spec(
+ TensorSpec("tensor(x{})")
+ .add({{"x","1"}}, 1)
+ .add({{"x","2"}}, 3)
+ .add({{"x","3"}}, 5));
+ }
+ Value::UP make_tensor_merge_rhs() {
+ return engine.from_spec(
+ TensorSpec("tensor(x{})")
+ .add({{"x","2"}}, 7)
+ .add({{"x","3"}}, 9)
+ .add({{"x","4"}}, 11));
+ }
+ Value::UP make_tensor_merge_output() {
+ return engine.from_spec(
+ TensorSpec("tensor(x{})")
+ .add({{"x","1"}}, 1)
+ .add({{"x","2"}}, 10)
+ .add({{"x","3"}}, 14)
+ .add({{"x","4"}}, 11));
+ }
};
void verify_equal(const Value &expect, const Value &value) {
@@ -237,6 +259,20 @@ TEST("require that tensor join works") {
TEST_DO(verify_equal(*expect, ctx.eval(prog)));
}
+TEST("require that tensor merge works") {
+ EvalCtx ctx(SimpleTensorEngine::ref());
+ size_t a_id = ctx.add_tensor(ctx.make_tensor_merge_lhs());
+ size_t b_id = ctx.add_tensor(ctx.make_tensor_merge_rhs());
+ Value::UP expect = ctx.make_tensor_merge_output();
+ const auto &fun = merge(inject(ValueType::from_spec("tensor(x{})"), a_id, ctx.stash),
+ inject(ValueType::from_spec("tensor(x{})"), b_id, ctx.stash),
+ operation::Add::f, ctx.stash);
+ EXPECT_TRUE(fun.result_is_mutable());
+ EXPECT_EQUAL(expect->type(), fun.result_type());
+ const auto &prog = ctx.compile(fun);
+ TEST_DO(verify_equal(*expect, ctx.eval(prog)));
+}
+
TEST("require that tensor concat works") {
EvalCtx ctx(SimpleTensorEngine::ref());
size_t a_id = ctx.add_tensor(ctx.make_tensor_matrix_first_half());
@@ -412,20 +448,25 @@ TEST("require that push_children works") {
EXPECT_EQUAL(&refs[2].get().get(), &a);
EXPECT_EQUAL(&refs[3].get().get(), &b);
//-------------------------------------------------------------------------
- concat(a, b, "x", stash).push_children(refs);
+ merge(a, b, operation::Add::f, stash).push_children(refs);
ASSERT_EQUAL(refs.size(), 6u);
EXPECT_EQUAL(&refs[4].get().get(), &a);
EXPECT_EQUAL(&refs[5].get().get(), &b);
//-------------------------------------------------------------------------
+ concat(a, b, "x", stash).push_children(refs);
+ ASSERT_EQUAL(refs.size(), 8u);
+ EXPECT_EQUAL(&refs[6].get().get(), &a);
+ EXPECT_EQUAL(&refs[7].get().get(), &b);
+ //-------------------------------------------------------------------------
rename(c, {}, {}, stash).push_children(refs);
- ASSERT_EQUAL(refs.size(), 7u);
- EXPECT_EQUAL(&refs[6].get().get(), &c);
+ ASSERT_EQUAL(refs.size(), 9u);
+ EXPECT_EQUAL(&refs[8].get().get(), &c);
//-------------------------------------------------------------------------
if_node(a, b, c, stash).push_children(refs);
- ASSERT_EQUAL(refs.size(), 10u);
- EXPECT_EQUAL(&refs[7].get().get(), &a);
- EXPECT_EQUAL(&refs[8].get().get(), &b);
- EXPECT_EQUAL(&refs[9].get().get(), &c);
+ ASSERT_EQUAL(refs.size(), 12u);
+ EXPECT_EQUAL(&refs[9].get().get(), &a);
+ EXPECT_EQUAL(&refs[10].get().get(), &b);
+ EXPECT_EQUAL(&refs[11].get().get(), &c);
//-------------------------------------------------------------------------
}
@@ -433,6 +474,7 @@ TEST("require that tensor function can be dumped for debugging") {
Stash stash;
auto my_value_1 = stash.create<DoubleValue>(1.0);
auto my_value_2 = stash.create<DoubleValue>(2.0);
+ auto my_value_3 = stash.create<DoubleValue>(3.0);
//-------------------------------------------------------------------------
const auto &x5 = inject(ValueType::from_spec("tensor(x[5])"), 0, stash);
const auto &mapped_x5 = map(x5, operation::Relu::f, stash);
@@ -452,7 +494,9 @@ TEST("require that tensor function can be dumped for debugging") {
const auto &concat_x5 = concat(x3, x2, "x", stash);
//-------------------------------------------------------------------------
const auto &const_2 = const_value(my_value_2, stash);
- const auto &root = if_node(const_2, joined_x5, concat_x5, stash);
+ const auto &const_3 = const_value(my_value_3, stash);
+ const auto &merged_double = merge(const_2, const_3, operation::Less::f, stash);
+ const auto &root = if_node(merged_double, joined_x5, concat_x5, stash);
EXPECT_EQUAL(root.result_type(), ValueType::from_spec("tensor(x[5])"));
fprintf(stderr, "function dump -->[[%s]]<-- function dump\n", root.as_string().c_str());
}
diff --git a/eval/src/tests/eval/value_type/value_type_test.cpp b/eval/src/tests/eval/value_type/value_type_test.cpp
index 85ff7613775..055c7e582bf 100644
--- a/eval/src/tests/eval/value_type/value_type_test.cpp
+++ b/eval/src/tests/eval/value_type/value_type_test.cpp
@@ -408,6 +408,25 @@ TEST("require that tensor dimensions can be renamed") {
EXPECT_EQUAL(type("error").rename({"a"}, {"b"}), type("error"));
}
+TEST("require that similar types can be merged") {
+ EXPECT_EQUAL(ValueType::merge(type("error"), type("error")), type("error"));
+ EXPECT_EQUAL(ValueType::merge(type("double"), type("double")), type("double"));
+ EXPECT_EQUAL(ValueType::merge(type("tensor(x[5])"), type("tensor(x[5])")), type("tensor(x[5])"));
+ EXPECT_EQUAL(ValueType::merge(type("tensor<float>(x[5])"), type("tensor(x[5])")), type("tensor(x[5])"));
+ EXPECT_EQUAL(ValueType::merge(type("tensor(x[5])"), type("tensor<float>(x[5])")), type("tensor(x[5])"));
+ EXPECT_EQUAL(ValueType::merge(type("tensor<float>(x[5])"), type("tensor<float>(x[5])")), type("tensor<float>(x[5])"));
+ EXPECT_EQUAL(ValueType::merge(type("tensor(x{})"), type("tensor(x{})")), type("tensor(x{})"));
+}
+
+TEST("require that diverging types can not be merged") {
+ EXPECT_EQUAL(ValueType::merge(type("error"), type("double")), type("error"));
+ EXPECT_EQUAL(ValueType::merge(type("double"), type("error")), type("error"));
+ EXPECT_EQUAL(ValueType::merge(type("tensor(x[5])"), type("double")), type("error"));
+ EXPECT_EQUAL(ValueType::merge(type("double"), type("tensor(x[5])")), type("error"));
+ EXPECT_EQUAL(ValueType::merge(type("tensor(x[5])"), type("tensor(x[3])")), type("error"));
+ EXPECT_EQUAL(ValueType::merge(type("tensor(x{})"), type("tensor(y{})")), type("error"));
+}
+
void verify_concat(const ValueType &a, const ValueType b, const vespalib::string &dim, const ValueType &res) {
EXPECT_EQUAL(ValueType::concat(a, b, dim), res);
EXPECT_EQUAL(ValueType::concat(b, a, dim), res);
diff --git a/eval/src/vespa/eval/eval/function.cpp b/eval/src/vespa/eval/eval/function.cpp
index 03e0ae5c0a2..53107fefb32 100644
--- a/eval/src/vespa/eval/eval/function.cpp
+++ b/eval/src/vespa/eval/eval/function.cpp
@@ -603,6 +603,15 @@ void parse_tensor_join(ParseContext &ctx) {
ctx.push_expression(std::make_unique<nodes::TensorJoin>(std::move(lhs), std::move(rhs), std::move(lambda)));
}
+void parse_tensor_merge(ParseContext &ctx) {
+ Node_UP lhs = get_expression(ctx);
+ ctx.eat(',');
+ Node_UP rhs = get_expression(ctx);
+ ctx.eat(',');
+ auto lambda = parse_lambda(ctx, 2);
+ ctx.push_expression(std::make_unique<nodes::TensorMerge>(std::move(lhs), std::move(rhs), std::move(lambda)));
+}
+
void parse_tensor_reduce(ParseContext &ctx) {
Node_UP child = get_expression(ctx);
ctx.eat(',');
@@ -862,6 +871,8 @@ bool maybe_parse_call(ParseContext &ctx, const vespalib::string &name) {
parse_tensor_map(ctx);
} else if (name == "join") {
parse_tensor_join(ctx);
+ } else if (name == "merge") {
+ parse_tensor_merge(ctx);
} else if (name == "reduce") {
parse_tensor_reduce(ctx);
} else if (name == "rename") {
diff --git a/eval/src/vespa/eval/eval/interpreted_function.cpp b/eval/src/vespa/eval/eval/interpreted_function.cpp
index b93ad1eb10c..0a630a3e20a 100644
--- a/eval/src/vespa/eval/eval/interpreted_function.cpp
+++ b/eval/src/vespa/eval/eval/interpreted_function.cpp
@@ -28,6 +28,9 @@ const Function *get_lambda(const nodes::Node &node) {
if (auto ptr = nodes::as<nodes::TensorJoin>(node)) {
return &ptr->lambda();
}
+ if (auto ptr = nodes::as<nodes::TensorMerge>(node)) {
+ return &ptr->lambda();
+ }
return nullptr;
}
diff --git a/eval/src/vespa/eval/eval/key_gen.cpp b/eval/src/vespa/eval/eval/key_gen.cpp
index 0f8382b5afc..46137b5878c 100644
--- a/eval/src/vespa/eval/eval/key_gen.cpp
+++ b/eval/src/vespa/eval/eval/key_gen.cpp
@@ -35,8 +35,9 @@ struct KeyGen : public NodeVisitor, public NodeTraverser {
void visit(const Not &) override { add_byte( 6); }
void visit(const If &node) override { add_byte( 7); add_double(node.p_true()); }
void visit(const Error &) override { add_byte( 9); }
- void visit(const TensorMap &) override { add_byte(11); } // lambda should be part of key
- void visit(const TensorJoin &) override { add_byte(12); } // lambda should be part of key
+ void visit(const TensorMap &) override { add_byte(10); } // lambda should be part of key
+ void visit(const TensorJoin &) override { add_byte(11); } // lambda should be part of key
+ void visit(const TensorMerge &) override { add_byte(12); } // lambda should be part of key
void visit(const TensorReduce &) override { add_byte(13); } // aggr/dimensions should be part of key
void visit(const TensorRename &) override { add_byte(14); } // dimensions should be part of key
void visit(const TensorConcat &) override { add_byte(15); } // dimension should be part of key
diff --git a/eval/src/vespa/eval/eval/llvm/compiled_function.cpp b/eval/src/vespa/eval/eval/llvm/compiled_function.cpp
index ca34c31e976..facfc502111 100644
--- a/eval/src/vespa/eval/eval/llvm/compiled_function.cpp
+++ b/eval/src/vespa/eval/eval/llvm/compiled_function.cpp
@@ -128,6 +128,7 @@ CompiledFunction::detect_issues(const Function &function)
void close(const nodes::Node &node) override {
if (nodes::check_type<nodes::TensorMap,
nodes::TensorJoin,
+ nodes::TensorMerge,
nodes::TensorReduce,
nodes::TensorRename,
nodes::TensorConcat,
diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
index 72406f09a0d..1d5515d7f4a 100644
--- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
+++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
@@ -467,6 +467,9 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser {
void visit(const TensorJoin &node) override {
make_error(node.num_children());
}
+ void visit(const TensorMerge &node) override {
+ make_error(node.num_children());
+ }
void visit(const TensorReduce &node) override {
make_error(node.num_children());
}
diff --git a/eval/src/vespa/eval/eval/make_tensor_function.cpp b/eval/src/vespa/eval/eval/make_tensor_function.cpp
index 803047d27c4..849270b89a7 100644
--- a/eval/src/vespa/eval/eval/make_tensor_function.cpp
+++ b/eval/src/vespa/eval/eval/make_tensor_function.cpp
@@ -77,6 +77,14 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser {
stack.back() = tensor_function::join(a, b, function, stash);
}
+ void make_merge(const Node &, join_fun_t function) {
+ assert(stack.size() >= 2);
+ const auto &b = stack.back();
+ stack.pop_back();
+ const auto &a = stack.back();
+ stack.back() = tensor_function::merge(a, b, function, stash);
+ }
+
void make_concat(const Node &, const vespalib::string &dimension) {
assert(stack.size() >= 2);
const auto &b = stack.back();
@@ -197,6 +205,10 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser {
make_join(node, token.get()->get().get_function<2>());
}
}
+ void visit(const TensorMerge &node) override {
+ const auto &token = stash.create<CompileCache::Token::UP>(CompileCache::compile(node.lambda(), PassParams::SEPARATE));
+ make_merge(node, token.get()->get().get_function<2>());
+ }
void visit(const TensorReduce &node) override {
make_reduce(node, node.aggr(), node.dimensions());
}
diff --git a/eval/src/vespa/eval/eval/node_types.cpp b/eval/src/vespa/eval/eval/node_types.cpp
index 972afb7663f..bf87628e301 100644
--- a/eval/src/vespa/eval/eval/node_types.cpp
+++ b/eval/src/vespa/eval/eval/node_types.cpp
@@ -97,6 +97,10 @@ struct TypeResolver : public NodeVisitor, public NodeTraverser {
}
void visit(const TensorMap &node) override { resolve_op1(node); }
void visit(const TensorJoin &node) override { resolve_op2(node); }
+ void visit(const TensorMerge &node) override {
+ bind(ValueType::merge(type(node.get_child(0)),
+ type(node.get_child(1))), node);
+ }
void visit(const TensorReduce &node) override {
const ValueType &child = type(node.get_child(0));
bind(child.reduce(node.dimensions()), node);
diff --git a/eval/src/vespa/eval/eval/node_visitor.h b/eval/src/vespa/eval/eval/node_visitor.h
index f32161a24c4..8f9722858b7 100644
--- a/eval/src/vespa/eval/eval/node_visitor.h
+++ b/eval/src/vespa/eval/eval/node_visitor.h
@@ -31,6 +31,7 @@ struct NodeVisitor {
// tensor nodes
virtual void visit(const nodes::TensorMap &) = 0;
virtual void visit(const nodes::TensorJoin &) = 0;
+ virtual void visit(const nodes::TensorMerge &) = 0;
virtual void visit(const nodes::TensorReduce &) = 0;
virtual void visit(const nodes::TensorRename &) = 0;
virtual void visit(const nodes::TensorConcat &) = 0;
@@ -99,7 +100,8 @@ struct EmptyNodeVisitor : NodeVisitor {
void visit(const nodes::If &) override {}
void visit(const nodes::Error &) override {}
void visit(const nodes::TensorMap &) override {}
- void visit(const nodes::TensorJoin &) override {}
+ void visit(const nodes::TensorJoin &) override {}
+ void visit(const nodes::TensorMerge &) override {}
void visit(const nodes::TensorReduce &) override {}
void visit(const nodes::TensorRename &) override {}
void visit(const nodes::TensorConcat &) override {}
diff --git a/eval/src/vespa/eval/eval/simple_tensor.cpp b/eval/src/vespa/eval/eval/simple_tensor.cpp
index b847d31335e..a72a24be211 100644
--- a/eval/src/vespa/eval/eval/simple_tensor.cpp
+++ b/eval/src/vespa/eval/eval/simple_tensor.cpp
@@ -3,6 +3,8 @@
#include "simple_tensor.h"
#include "simple_tensor_engine.h"
#include "operation.h"
+#include <vespa/vespalib/util/overload.h>
+#include <vespa/vespalib/util/visit_ranges.h>
#include <vespa/vespalib/objects/nbostream.h>
#include <algorithm>
#include <cassert>
@@ -678,6 +680,24 @@ SimpleTensor::join(const SimpleTensor &a, const SimpleTensor &b, join_fun_t func
}
std::unique_ptr<SimpleTensor>
+SimpleTensor::merge(const SimpleTensor &a, const SimpleTensor &b, join_fun_t function)
+{
+ ValueType result_type = ValueType::merge(a.type(), b.type());
+ if (result_type.is_error()) {
+ return std::make_unique<SimpleTensor>();
+ }
+ Builder builder(result_type);
+ auto cmp = [](const Cell &x, const Cell &y) { return (x.address < y.address); };
+ auto visitor = overload{
+ [&builder](visit_ranges_either, const Cell &x) { builder.set(x.address, x.value); },
+ [&builder,function](visit_ranges_both, const Cell &x, const Cell &y) {
+ builder.set(x.address, function(x.value, y.value));
+ }};
+ visit_ranges(visitor, a._cells.begin(), a._cells.end(), b._cells.begin(), b._cells.end(), cmp);
+ return builder.build();
+}
+
+std::unique_ptr<SimpleTensor>
SimpleTensor::concat(const SimpleTensor &a, const SimpleTensor &b, const vespalib::string &dimension)
{
ValueType result_type = ValueType::concat(a.type(), b.type(), dimension);
diff --git a/eval/src/vespa/eval/eval/simple_tensor.h b/eval/src/vespa/eval/eval/simple_tensor.h
index 45d1853824d..cbf1ac99e05 100644
--- a/eval/src/vespa/eval/eval/simple_tensor.h
+++ b/eval/src/vespa/eval/eval/simple_tensor.h
@@ -89,6 +89,7 @@ public:
std::unique_ptr<SimpleTensor> rename(const std::vector<vespalib::string> &from, const std::vector<vespalib::string> &to) const;
static std::unique_ptr<SimpleTensor> create(const TensorSpec &spec);
static std::unique_ptr<SimpleTensor> join(const SimpleTensor &a, const SimpleTensor &b, join_fun_t function);
+ static std::unique_ptr<SimpleTensor> merge(const SimpleTensor &a, const SimpleTensor &b, join_fun_t function);
static std::unique_ptr<SimpleTensor> concat(const SimpleTensor &a, const SimpleTensor &b, const vespalib::string &dimension);
static void encode(const SimpleTensor &tensor, nbostream &output);
static std::unique_ptr<SimpleTensor> decode(nbostream &input);
diff --git a/eval/src/vespa/eval/eval/simple_tensor_engine.cpp b/eval/src/vespa/eval/eval/simple_tensor_engine.cpp
index 25fc98fc00a..6c2a0fcd53d 100644
--- a/eval/src/vespa/eval/eval/simple_tensor_engine.cpp
+++ b/eval/src/vespa/eval/eval/simple_tensor_engine.cpp
@@ -121,6 +121,12 @@ SimpleTensorEngine::join(const Value &a, const Value &b, join_fun_t function, St
}
const Value &
+SimpleTensorEngine::merge(const Value &a, const Value &b, join_fun_t function, Stash &stash) const
+{
+ return to_value(SimpleTensor::merge(to_simple(a, stash), to_simple(b, stash), function), stash);
+}
+
+const Value &
SimpleTensorEngine::reduce(const Value &a, Aggr aggr, const std::vector<vespalib::string> &dimensions, Stash &stash) const
{
return to_value(to_simple(a, stash).reduce(Aggregator::create(aggr, stash), dimensions), stash);
diff --git a/eval/src/vespa/eval/eval/simple_tensor_engine.h b/eval/src/vespa/eval/eval/simple_tensor_engine.h
index 645ec4c4be7..4c71e91c8d3 100644
--- a/eval/src/vespa/eval/eval/simple_tensor_engine.h
+++ b/eval/src/vespa/eval/eval/simple_tensor_engine.h
@@ -27,6 +27,7 @@ public:
const Value &map(const Value &a, map_fun_t function, Stash &stash) const override;
const Value &join(const Value &a, const Value &b, join_fun_t function, Stash &stash) const override;
+ const Value &merge(const Value &a, const Value &b, join_fun_t function, Stash &stash) const override;
const Value &reduce(const Value &a, Aggr aggr, const std::vector<vespalib::string> &dimensions, Stash &stash) const override;
const Value &concat(const Value &a, const Value &b, const vespalib::string &dimension, Stash &stash) const override;
const Value &rename(const Value &a, const std::vector<vespalib::string> &from, const std::vector<vespalib::string> &to, Stash &stash) const override;
diff --git a/eval/src/vespa/eval/eval/tensor_engine.h b/eval/src/vespa/eval/eval/tensor_engine.h
index 0ba25baed8c..f85f57cfa4f 100644
--- a/eval/src/vespa/eval/eval/tensor_engine.h
+++ b/eval/src/vespa/eval/eval/tensor_engine.h
@@ -49,6 +49,7 @@ struct TensorEngine
virtual const Value &map(const Value &a, map_fun_t function, Stash &stash) const = 0;
virtual const Value &join(const Value &a, const Value &b, join_fun_t function, Stash &stash) const = 0;
+ virtual const Value &merge(const Value &a, const Value &b, join_fun_t function, Stash &stash) const = 0;
virtual const Value &reduce(const Value &a, Aggr aggr, const std::vector<vespalib::string> &dimensions, Stash &stash) const = 0;
virtual const Value &concat(const Value &a, const Value &b, const vespalib::string &dimension, Stash &stash) const = 0;
virtual const Value &rename(const Value &a, const std::vector<vespalib::string> &from, const std::vector<vespalib::string> &to, Stash &stash) const = 0;
diff --git a/eval/src/vespa/eval/eval/tensor_function.cpp b/eval/src/vespa/eval/eval/tensor_function.cpp
index 228a723e86a..45e8094570e 100644
--- a/eval/src/vespa/eval/eval/tensor_function.cpp
+++ b/eval/src/vespa/eval/eval/tensor_function.cpp
@@ -101,6 +101,10 @@ void op_tensor_join(State &state, uint64_t param) {
state.pop_pop_push(state.engine.join(state.peek(1), state.peek(0), to_join_fun(param), state.stash));
}
+void op_tensor_merge(State &state, uint64_t param) {
+ state.pop_pop_push(state.engine.merge(state.peek(1), state.peek(0), to_join_fun(param), state.stash));
+}
+
using ReduceParams = std::pair<Aggr,std::vector<vespalib::string>>;
void op_tensor_reduce(State &state, uint64_t param) {
const ReduceParams &params = unwrap_param<ReduceParams>(param);
@@ -324,6 +328,21 @@ Join::visit_self(vespalib::ObjectVisitor &visitor) const
//-----------------------------------------------------------------------------
Instruction
+Merge::compile_self(Stash &) const
+{
+ return Instruction(op_tensor_merge, to_param(_function));
+}
+
+void
+Merge::visit_self(vespalib::ObjectVisitor &visitor) const
+{
+ Super::visit_self(visitor);
+ ::visit(visitor, "function", _function);
+}
+
+//-----------------------------------------------------------------------------
+
+Instruction
Concat::compile_self(Stash &) const
{
return Instruction(op_tensor_concat, wrap_param<vespalib::string>(_dimension));
@@ -471,6 +490,11 @@ const Node &join(const Node &lhs, const Node &rhs, join_fun_t function, Stash &s
return stash.create<Join>(result_type, lhs, rhs, function);
}
+const Node &merge(const Node &lhs, const Node &rhs, join_fun_t function, Stash &stash) {
+ ValueType result_type = ValueType::merge(lhs.result_type(), rhs.result_type());
+ return stash.create<Merge>(result_type, lhs, rhs, function);
+}
+
const Node &concat(const Node &lhs, const Node &rhs, const vespalib::string &dimension, Stash &stash) {
ValueType result_type = ValueType::concat(lhs.result_type(), rhs.result_type(), dimension);
return stash.create<Concat>(result_type, lhs, rhs, dimension);
diff --git a/eval/src/vespa/eval/eval/tensor_function.h b/eval/src/vespa/eval/eval/tensor_function.h
index b019ab64e18..4b862e9ec6a 100644
--- a/eval/src/vespa/eval/eval/tensor_function.h
+++ b/eval/src/vespa/eval/eval/tensor_function.h
@@ -266,6 +266,25 @@ public:
//-----------------------------------------------------------------------------
+class Merge : public Op2
+{
+ using Super = Op2;
+private:
+ join_fun_t _function;
+public:
+ Merge(const ValueType &result_type_in,
+ const TensorFunction &lhs_in,
+ const TensorFunction &rhs_in,
+ join_fun_t function_in)
+ : Op2(result_type_in, lhs_in, rhs_in), _function(function_in) {}
+ join_fun_t function() const { return _function; }
+ bool result_is_mutable() const override { return true; }
+ InterpretedFunction::Instruction compile_self(Stash &stash) const override;
+ void visit_self(vespalib::ObjectVisitor &visitor) const override;
+};
+
+//-----------------------------------------------------------------------------
+
class Concat : public Op2
{
using Super = Op2;
@@ -394,6 +413,7 @@ const Node &inject(const ValueType &type, size_t param_idx, Stash &stash);
const Node &reduce(const Node &child, Aggr aggr, const std::vector<vespalib::string> &dimensions, Stash &stash);
const Node &map(const Node &child, map_fun_t function, Stash &stash);
const Node &join(const Node &lhs, const Node &rhs, join_fun_t function, Stash &stash);
+const Node &merge(const Node &lhs, const Node &rhs, join_fun_t function, Stash &stash);
const Node &concat(const Node &lhs, const Node &rhs, const vespalib::string &dimension, Stash &stash);
const Node &create(const ValueType &type, const std::map<TensorSpec::Address, Node::CREF> &spec, Stash &stash);
const Node &peek(const Node &param, const std::map<vespalib::string, std::variant<TensorSpec::Label, Node::CREF>> &spec, Stash &stash);
diff --git a/eval/src/vespa/eval/eval/tensor_nodes.cpp b/eval/src/vespa/eval/eval/tensor_nodes.cpp
index e392248a08a..82d108300dd 100644
--- a/eval/src/vespa/eval/eval/tensor_nodes.cpp
+++ b/eval/src/vespa/eval/eval/tensor_nodes.cpp
@@ -9,6 +9,7 @@ namespace nodes {
void TensorMap ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
void TensorJoin ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
+void TensorMerge ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
void TensorReduce::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
void TensorRename::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
void TensorConcat::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
diff --git a/eval/src/vespa/eval/eval/tensor_nodes.h b/eval/src/vespa/eval/eval/tensor_nodes.h
index 4213809f9e3..daba46c1fc5 100644
--- a/eval/src/vespa/eval/eval/tensor_nodes.h
+++ b/eval/src/vespa/eval/eval/tensor_nodes.h
@@ -76,6 +76,38 @@ public:
}
};
+class TensorMerge : public Node {
+private:
+ Node_UP _lhs;
+ Node_UP _rhs;
+ std::shared_ptr<Function const> _lambda;
+public:
+ TensorMerge(Node_UP lhs, Node_UP rhs, std::shared_ptr<Function const> lambda)
+ : _lhs(std::move(lhs)), _rhs(std::move(rhs)), _lambda(std::move(lambda)) {}
+ const Function &lambda() const { return *_lambda; }
+ vespalib::string dump(DumpContext &ctx) const override {
+ vespalib::string str;
+ str += "join(";
+ str += _lhs->dump(ctx);
+ str += ",";
+ str += _rhs->dump(ctx);
+ str += ",";
+ str += _lambda->dump_as_lambda();
+ str += ")";
+ return str;
+ }
+ void accept(NodeVisitor &visitor) const override ;
+ size_t num_children() const override { return 2; }
+ const Node &get_child(size_t idx) const override {
+ assert(idx < 2);
+ return (idx == 0) ? *_lhs : *_rhs;
+ }
+ void detach_children(NodeHandler &handler) override {
+ handler.handle(std::move(_lhs));
+ handler.handle(std::move(_rhs));
+ }
+};
+
class TensorReduce : public Node {
private:
Node_UP _child;
diff --git a/eval/src/vespa/eval/eval/test/eval_spec.cpp b/eval/src/vespa/eval/eval/test/eval_spec.cpp
index dc1e3405a6b..709234a1a2c 100644
--- a/eval/src/vespa/eval/eval/test/eval_spec.cpp
+++ b/eval/src/vespa/eval/eval/test/eval_spec.cpp
@@ -165,6 +165,8 @@ EvalSpec::add_tensor_operation_cases() {
add_rule({"a", -1.0, 1.0}, "map(a,f(x)(x+x*3))", [](double x){ return (x + (x * 3)); });
add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "join(a,b,f(x,y)(x+y))", [](double x, double y){ return (x + y); });
add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "join(a,b,f(x,y)(x+y*3))", [](double x, double y){ return (x + (y * 3)); });
+ add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "merge(a,b,f(x,y)(x+y))", [](double x, double y){ return (x + y); });
+ add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "merge(a,b,f(x,y)(x+y*3))", [](double x, double y){ return (x + (y * 3)); });
add_rule({"a", -1.0, 1.0}, "reduce(a,avg)", [](double a){ return a; });
add_rule({"a", -1.0, 1.0}, "reduce(a,count)", [](double){ return 1.0; });
add_rule({"a", -1.0, 1.0}, "reduce(a,prod)", [](double a){ return a; });
diff --git a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp
index 9242b19310d..41c6dd21e24 100644
--- a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp
+++ b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp
@@ -829,6 +829,32 @@ struct TestContext {
//-------------------------------------------------------------------------
+ void test_tensor_merge(const vespalib::string &type_base, const vespalib::string &a_str,
+ const vespalib::string &b_str, const vespalib::string &expect_str)
+ {
+ vespalib::string expr = "merge(a,b,f(x,y)(2*x+y))";
+ for (bool a_float: {false, true}) {
+ for (bool b_float: {false, true}) {
+ bool both_float = a_float && b_float;
+ vespalib::string a_expr = make_string("tensor%s(%s):%s", a_float ? "<float>" : "", type_base.c_str(), a_str.c_str());
+ vespalib::string b_expr = make_string("tensor%s(%s):%s", b_float ? "<float>" : "", type_base.c_str(), b_str.c_str());
+ vespalib::string expect_expr = make_string("tensor%s(%s):%s", both_float ? "<float>" : "", type_base.c_str(), expect_str.c_str());
+ TensorSpec a = spec(a_expr);
+ TensorSpec b = spec(b_expr);
+ TensorSpec expect = spec(expect_expr);
+ TEST_DO(verify_result(Expr_TT(expr).eval(engine, a, b), expect));
+ }
+ }
+ }
+
+ void test_tensor_merge() {
+ TEST_DO(test_tensor_merge("x[3]", "[1,2,3]", "[4,5,6]", "[6,9,12]"));
+ TEST_DO(test_tensor_merge("x{}", "{a:1,b:2,c:3}", "{b:4,c:5,d:6}", "{a:1,b:8,c:11,d:6}"));
+ TEST_DO(test_tensor_merge("x{},y[2]", "{a:[1,2],b:[3,4]}", "{b:[5,6],c:[6,7]}", "{a:[1,2],b:[11,14],c:[6,7]}"));
+ }
+
+ //-------------------------------------------------------------------------
+
void verify_encode_decode(const TensorSpec &spec,
const TensorEngine &encode_engine,
const TensorEngine &decode_engine)
@@ -913,6 +939,7 @@ struct TestContext {
TEST_DO(test_tensor_lambda());
TEST_DO(test_tensor_create());
TEST_DO(test_tensor_peek());
+ TEST_DO(test_tensor_merge());
TEST_DO(test_binary_format());
}
};
diff --git a/eval/src/vespa/eval/eval/test/tensor_model.hpp b/eval/src/vespa/eval/eval/test/tensor_model.hpp
index 6efb7470d55..2466701df62 100644
--- a/eval/src/vespa/eval/eval/test/tensor_model.hpp
+++ b/eval/src/vespa/eval/eval/test/tensor_model.hpp
@@ -5,6 +5,10 @@
#include <vespa/eval/eval/value_type.h>
#include <vespa/eval/eval/operation.h>
#include <vespa/eval/eval/tensor_engine.h>
+#include <vespa/eval/eval/function.h>
+#include <vespa/eval/eval/node_types.h>
+#include <vespa/eval/eval/interpreted_function.h>
+#include <vespa/eval/eval/simple_tensor_engine.h>
namespace vespalib {
namespace eval {
@@ -292,6 +296,24 @@ TensorSpec spec(const vespalib::string &type,
return spec;
}
+TensorSpec spec(const vespalib::string &value_expr) {
+ if (value_expr == "error") {
+ return TensorSpec("error");
+ }
+ const auto &engine = SimpleTensorEngine::ref();
+ auto fun = Function::parse(value_expr);
+ ASSERT_TRUE(!fun->has_error());
+ ASSERT_EQUAL(fun->num_params(), 0u);
+ NodeTypes types(*fun, {});
+ ASSERT_TRUE(!types.get_type(fun->root()).is_error());
+ InterpretedFunction ifun(engine, *fun, types);
+ InterpretedFunction::Context ctx(ifun);
+ SimpleObjectParams params({});
+ auto result = engine.to_spec(ifun.eval(ctx, params));
+ ASSERT_TRUE(!result.cells().empty());
+ return result;
+}
+
} // namespace vespalib::eval::test
} // namespace vespalib::eval
} // namespace vespalib
diff --git a/eval/src/vespa/eval/eval/value_type.cpp b/eval/src/vespa/eval/eval/value_type.cpp
index 211a9c305a3..36c7c49c8b9 100644
--- a/eval/src/vespa/eval/eval/value_type.cpp
+++ b/eval/src/vespa/eval/eval/value_type.cpp
@@ -289,6 +289,20 @@ ValueType::join(const ValueType &lhs, const ValueType &rhs)
return tensor_type(std::move(result.dimensions), unify(lhs._cell_type, rhs._cell_type));
}
+ValueType
+ValueType::merge(const ValueType &lhs, const ValueType &rhs)
+{
+ if ((lhs.type() != rhs.type()) ||
+ (lhs.dimensions() != rhs.dimensions()))
+ {
+ return error_type();
+ }
+ if (lhs.dimensions().empty()) {
+ return lhs;
+ }
+ return tensor_type(lhs.dimensions(), unify(lhs._cell_type, rhs._cell_type));
+}
+
CellType
ValueType::unify_cell_types(const ValueType &a, const ValueType &b) {
if (a.is_double()) {
diff --git a/eval/src/vespa/eval/eval/value_type.h b/eval/src/vespa/eval/eval/value_type.h
index b02053be3cb..1a4182dff2a 100644
--- a/eval/src/vespa/eval/eval/value_type.h
+++ b/eval/src/vespa/eval/eval/value_type.h
@@ -80,6 +80,7 @@ public:
static ValueType from_spec(const vespalib::string &spec, std::vector<ValueType::Dimension> &unsorted);
vespalib::string to_spec() const;
static ValueType join(const ValueType &lhs, const ValueType &rhs);
+ static ValueType merge(const ValueType &lhs, const ValueType &rhs);
static CellType unify_cell_types(const ValueType &a, const ValueType &b);
static ValueType concat(const ValueType &lhs, const ValueType &rhs, const vespalib::string &dimension);
static ValueType either(const ValueType &one, const ValueType &other);
diff --git a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
index b310ad72fc1..b4449309812 100644
--- a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
+++ b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
@@ -84,6 +84,7 @@ const Value &to_default(const Value &value, Stash &stash) {
}
const Value &to_value(std::unique_ptr<Tensor> tensor, Stash &stash) {
+ assert(tensor);
if (tensor->type().is_tensor()) {
return *stash.create<Value::UP>(std::move(tensor));
}
@@ -101,6 +102,10 @@ const Value &fallback_join(const Value &a, const Value &b, join_fun_t function,
return to_default(simple_engine().join(to_simple(a, stash), to_simple(b, stash), function, stash), stash);
}
+const Value &fallback_merge(const Value &a, const Value &b, join_fun_t function, Stash &stash) {
+ return to_default(simple_engine().merge(to_simple(a, stash), to_simple(b, stash), function, stash), stash);
+}
+
const Value &fallback_reduce(const Value &a, eval::Aggr aggr, const std::vector<vespalib::string> &dimensions, Stash &stash) {
return to_default(simple_engine().reduce(to_simple(a, stash), aggr, dimensions, stash), stash);
}
@@ -329,6 +334,25 @@ DefaultTensorEngine::join(const Value &a, const Value &b, join_fun_t function, S
}
const Value &
+DefaultTensorEngine::merge(const Value &a, const Value &b, join_fun_t function, Stash &stash) const
+{
+ if (auto tensor_a = a.as_tensor()) {
+ auto tensor_b = b.as_tensor();
+ assert(tensor_b);
+ assert(&tensor_a->engine() == this);
+ assert(&tensor_b->engine() == this);
+ const tensor::Tensor &my_a = static_cast<const tensor::Tensor &>(*tensor_a);
+ const tensor::Tensor &my_b = static_cast<const tensor::Tensor &>(*tensor_b);
+ if (!tensor::Tensor::supported({my_a.type(), my_b.type()})) {
+ return fallback_merge(a, b, function, stash);
+ }
+ return to_value(my_a.merge(function, my_b), stash);
+ } else {
+ return stash.create<DoubleValue>(function(a.as_double(), b.as_double()));
+ }
+}
+
+const Value &
DefaultTensorEngine::reduce(const Value &a, Aggr aggr, const std::vector<vespalib::string> &dimensions, Stash &stash) const
{
if (auto tensor = a.as_tensor()) {
diff --git a/eval/src/vespa/eval/tensor/default_tensor_engine.h b/eval/src/vespa/eval/tensor/default_tensor_engine.h
index 29f9811e170..5c39706b326 100644
--- a/eval/src/vespa/eval/tensor/default_tensor_engine.h
+++ b/eval/src/vespa/eval/tensor/default_tensor_engine.h
@@ -28,6 +28,7 @@ public:
const Value &map(const Value &a, map_fun_t function, Stash &stash) const override;
const Value &join(const Value &a, const Value &b, join_fun_t function, Stash &stash) const override;
+ const Value &merge(const Value &a, const Value &b, join_fun_t function, Stash &stash) const override;
const Value &reduce(const Value &a, Aggr aggr, const std::vector<vespalib::string> &dimensions, Stash &stash) const override;
const Value &concat(const Value &a, const Value &b, const vespalib::string &dimension, Stash &stash) const override;
const Value &rename(const Value &a, const std::vector<vespalib::string> &from, const std::vector<vespalib::string> &to, Stash &stash) const override;
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp
index 3fed84323ca..4ed6758dfde 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp
@@ -90,7 +90,7 @@ checkDimensions(const DenseTensorView &lhs, const DenseTensorView &rhs,
template <typename LCT, typename RCT, typename Function>
static Tensor::UP
sameShapeJoin(const ConstArrayRef<LCT> &lhs, const ConstArrayRef<RCT> &rhs,
- const eval::ValueType &lhs_type,
+ const std::vector<eval::ValueType::Dimension> &lhs_dims,
Function &&func)
{
size_t sz = lhs.size();
@@ -106,7 +106,7 @@ sameShapeJoin(const ConstArrayRef<LCT> &lhs, const ConstArrayRef<RCT> &rhs,
}
assert(rhsCellItr == rhs.cend());
assert(newCells.size() == sz);
- auto newType = eval::ValueType::tensor_type(lhs_type.dimensions(), eval::get_cell_type<OCT>());
+ auto newType = eval::ValueType::tensor_type(lhs_dims, eval::get_cell_type<OCT>());
return std::make_unique<DenseTensor<OCT>>(std::move(newType), std::move(newCells));
}
@@ -115,10 +115,10 @@ struct CallJoin
template <typename LCT, typename RCT, typename Function>
static Tensor::UP
call(const ConstArrayRef<LCT> &lhs, const ConstArrayRef<RCT> &rhs,
- const eval::ValueType &lhs_type,
+ const std::vector<eval::ValueType::Dimension> &lhs_dims,
Function &&func)
{
- return sameShapeJoin(lhs, rhs, lhs_type, std::move(func));
+ return sameShapeJoin(lhs, rhs, lhs_dims, std::move(func));
}
};
@@ -129,7 +129,7 @@ joinDenseTensors(const DenseTensorView &lhs, const DenseTensorView &rhs,
{
TypedCells lhsCells = lhs.cellsRef();
TypedCells rhsCells = rhs.cellsRef();
- return dispatch_2<CallJoin>(lhsCells, rhsCells, lhs.fast_type(), std::move(func));
+ return dispatch_2<CallJoin>(lhsCells, rhsCells, lhs.fast_type().dimensions(), std::move(func));
}
template <typename Function>
@@ -289,7 +289,7 @@ DenseTensorView::accept(TensorVisitor &visitor) const
Tensor::UP
DenseTensorView::join(join_fun_t function, const Tensor &arg) const
{
- if (fast_type() == arg.type()) {
+ if (fast_type().dimensions() == arg.type().dimensions()) {
if (function == eval::operation::Mul::f) {
return joinDenseTensors(*this, arg, "mul",
[](double a, double b) { return (a * b); });
@@ -310,6 +310,13 @@ DenseTensorView::join(join_fun_t function, const Tensor &arg) const
}
Tensor::UP
+DenseTensorView::merge(join_fun_t function, const Tensor &arg) const
+{
+ assert(fast_type().dimensions() == arg.type().dimensions());
+ return join(function, arg);
+}
+
+Tensor::UP
DenseTensorView::reduce_all(join_fun_t op, const std::vector<vespalib::string> &dims) const
{
if (op == eval::operation::Mul::f) {
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h
index 778f2aa2871..33183d267c1 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h
@@ -34,6 +34,7 @@ public:
double as_double() const override;
Tensor::UP apply(const CellFunction &func) const override;
Tensor::UP join(join_fun_t function, const Tensor &arg) const override;
+ Tensor::UP merge(join_fun_t function, const Tensor &arg) const override;
Tensor::UP reduce(join_fun_t op, const std::vector<vespalib::string> &dimensions) const override;
std::unique_ptr<Tensor> modify(join_fun_t op, const CellValues &cellValues) const override;
std::unique_ptr<Tensor> add(const Tensor &arg) const override;
diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp
index bcfbc851e6d..1fc93e8234f 100644
--- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp
+++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp
@@ -8,6 +8,7 @@
#include "sparse_tensor_modify.h"
#include "sparse_tensor_reduce.hpp"
#include "sparse_tensor_remove.h"
+#include "direct_sparse_tensor_builder.h"
#include <vespa/eval/eval/operation.h>
#include <vespa/eval/tensor/cell_values.h>
#include <vespa/eval/tensor/tensor_address_builder.h>
@@ -175,6 +176,30 @@ SparseTensor::join(join_fun_t function, const Tensor &arg) const
}
Tensor::UP
+SparseTensor::merge(join_fun_t function, const Tensor &arg) const
+{
+ const SparseTensor *rhs = dynamic_cast<const SparseTensor *>(&arg);
+ assert(rhs && (fast_type().dimensions() == rhs->fast_type().dimensions()));
+ DirectSparseTensorBuilder builder(eval::ValueType::merge(fast_type(), rhs->fast_type()));
+ builder.reserve(cells().size() + rhs->cells().size());
+ for (const auto &cell: cells()) {
+ auto pos = rhs->cells().find(cell.first);
+ if (pos == rhs->cells().end()) {
+ builder.insertCell(cell.first, cell.second);
+ } else {
+ builder.insertCell(cell.first, function(cell.second, pos->second));
+ }
+ }
+ for (const auto &cell: rhs->cells()) {
+ auto pos = cells().find(cell.first);
+ if (pos == cells().end()) {
+ builder.insertCell(cell.first, cell.second);
+ }
+ }
+ return builder.build();
+}
+
+Tensor::UP
SparseTensor::reduce(join_fun_t op,
const std::vector<vespalib::string> &dimensions) const
{
diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h
index c182c09c6b0..880cd32c605 100644
--- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h
+++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h
@@ -44,6 +44,7 @@ public:
double as_double() const override;
Tensor::UP apply(const CellFunction &func) const override;
Tensor::UP join(join_fun_t function, const Tensor &arg) const override;
+ Tensor::UP merge(join_fun_t function, const Tensor &arg) const override;
Tensor::UP reduce(join_fun_t op, const std::vector<vespalib::string> &dimensions) const override;
std::unique_ptr<Tensor> modify(join_fun_t op, const CellValues &cellValues) const override;
std::unique_ptr<Tensor> add(const Tensor &arg) const override;
diff --git a/eval/src/vespa/eval/tensor/tensor.h b/eval/src/vespa/eval/tensor/tensor.h
index edf5fa710e3..d822c99a6d8 100644
--- a/eval/src/vespa/eval/tensor/tensor.h
+++ b/eval/src/vespa/eval/tensor/tensor.h
@@ -33,6 +33,7 @@ public:
virtual ~Tensor() {}
virtual Tensor::UP apply(const CellFunction &func) const = 0;
virtual Tensor::UP join(join_fun_t function, const Tensor &arg) const = 0;
+ virtual Tensor::UP merge(join_fun_t function, const Tensor &arg) const = 0;
virtual Tensor::UP reduce(join_fun_t op, const std::vector<vespalib::string> &dimensions) const = 0;
/*
diff --git a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp
index a982a4b0fe1..7c09bc4e4ab 100644
--- a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp
+++ b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp
@@ -76,6 +76,12 @@ WrappedSimpleTensor::join(join_fun_t, const Tensor &) const
}
Tensor::UP
+WrappedSimpleTensor::merge(join_fun_t, const Tensor &) const
+{
+ LOG_ABORT("should not be reached");
+}
+
+Tensor::UP
WrappedSimpleTensor::reduce(join_fun_t, const std::vector<vespalib::string> &) const
{
LOG_ABORT("should not be reached");
diff --git a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h
index e7ffe7a755f..12ee1237d67 100644
--- a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h
+++ b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h
@@ -37,6 +37,7 @@ public:
// functions below should not be used for this implementation
Tensor::UP apply(const CellFunction &) const override;
Tensor::UP join(join_fun_t, const Tensor &) const override;
+ Tensor::UP merge(join_fun_t, const Tensor &) const override;
Tensor::UP reduce(join_fun_t, const std::vector<vespalib::string> &) const override;
std::unique_ptr<Tensor> modify(join_fun_t, const CellValues &) const override;
std::unique_ptr<Tensor> add(const Tensor &arg) const override;