summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@yahooinc.com>2023-10-26 08:34:13 +0000
committerHåvard Pettersen <havardpe@yahooinc.com>2023-11-01 15:20:13 +0000
commitc70562eb766a5205fba797f62456652919e7cd3d (patch)
treeec0b81cc3ff37c489db3931743e0cc16b0c2cd66 /eval
parente8c83ca7ee15b11768e9a9b50aea479103139277 (diff)
map_subspaces operation
Diffstat (limited to 'eval')
-rw-r--r--eval/CMakeLists.txt1
-rw-r--r--eval/src/tests/eval/map_subspaces/CMakeLists.txt8
-rw-r--r--eval/src/tests/eval/map_subspaces/map_subspaces_test.cpp103
-rw-r--r--eval/src/tests/eval/node_types/node_types_test.cpp36
-rw-r--r--eval/src/tests/eval/reference_evaluation/reference_evaluation_test.cpp6
-rw-r--r--eval/src/tests/eval/reference_operations/reference_operations_test.cpp62
-rw-r--r--eval/src/tests/eval/value_type/value_type_test.cpp30
-rw-r--r--eval/src/vespa/eval/eval/cell_type.h3
-rw-r--r--eval/src/vespa/eval/eval/function.cpp9
-rw-r--r--eval/src/vespa/eval/eval/key_gen.cpp117
-rw-r--r--eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp3
-rw-r--r--eval/src/vespa/eval/eval/make_tensor_function.cpp9
-rw-r--r--eval/src/vespa/eval/eval/node_tools.cpp109
-rw-r--r--eval/src/vespa/eval/eval/node_types.cpp26
-rw-r--r--eval/src/vespa/eval/eval/node_visitor.h250
-rw-r--r--eval/src/vespa/eval/eval/tensor_function.cpp20
-rw-r--r--eval/src/vespa/eval/eval/tensor_function.h24
-rw-r--r--eval/src/vespa/eval/eval/tensor_nodes.cpp21
-rw-r--r--eval/src/vespa/eval/eval/tensor_nodes.h30
-rw-r--r--eval/src/vespa/eval/eval/tensor_spec.cpp22
-rw-r--r--eval/src/vespa/eval/eval/test/reference_evaluation.cpp10
-rw-r--r--eval/src/vespa/eval/eval/test/reference_operations.cpp46
-rw-r--r--eval/src/vespa/eval/eval/test/reference_operations.h2
-rw-r--r--eval/src/vespa/eval/eval/value_type.cpp68
-rw-r--r--eval/src/vespa/eval/eval/value_type.h3
-rw-r--r--eval/src/vespa/eval/instruction/CMakeLists.txt1
-rw-r--r--eval/src/vespa/eval/instruction/generic_map_subspaces.cpp118
-rw-r--r--eval/src/vespa/eval/instruction/generic_map_subspaces.h17
28 files changed, 879 insertions, 275 deletions
diff --git a/eval/CMakeLists.txt b/eval/CMakeLists.txt
index fac6691bbff..11d5ecddaf5 100644
--- a/eval/CMakeLists.txt
+++ b/eval/CMakeLists.txt
@@ -29,6 +29,7 @@ vespa_define_module(
src/tests/eval/int8float
src/tests/eval/interpreted_function
src/tests/eval/llvm_stress
+ src/tests/eval/map_subspaces
src/tests/eval/multiply_add
src/tests/eval/nested_loop
src/tests/eval/node_tools
diff --git a/eval/src/tests/eval/map_subspaces/CMakeLists.txt b/eval/src/tests/eval/map_subspaces/CMakeLists.txt
new file mode 100644
index 00000000000..90b2ce07791
--- /dev/null
+++ b/eval/src/tests/eval/map_subspaces/CMakeLists.txt
@@ -0,0 +1,8 @@
+# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+vespa_add_executable(eval_map_subspaces_test_app TEST
+ SOURCES
+ map_subspaces_test.cpp
+ DEPENDS
+ vespaeval
+)
+vespa_add_test(NAME eval_map_subspaces_test_app COMMAND eval_map_subspaces_test_app)
diff --git a/eval/src/tests/eval/map_subspaces/map_subspaces_test.cpp b/eval/src/tests/eval/map_subspaces/map_subspaces_test.cpp
new file mode 100644
index 00000000000..278d49992be
--- /dev/null
+++ b/eval/src/tests/eval/map_subspaces/map_subspaces_test.cpp
@@ -0,0 +1,103 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/vespalib/testkit/test_kit.h>
+#include <vespa/eval/eval/tensor_function.h>
+#include <vespa/eval/eval/simple_value.h>
+#include <vespa/eval/eval/fast_value.h>
+#include <vespa/eval/eval/test/gen_spec.h>
+#include <vespa/eval/eval/test/eval_fixture.h>
+#include <vespa/eval/eval/tensor_nodes.h>
+
+#include <vespa/vespalib/util/stringfmt.h>
+#include <vespa/vespalib/util/stash.h>
+
+using namespace vespalib;
+using namespace vespalib::eval;
+using namespace vespalib::eval::test;
+using namespace vespalib::eval::tensor_function;
+
+void verify(const vespalib::string &a, const vespalib::string &expr, const vespalib::string &result) {
+ EvalFixture::ParamRepo param_repo;
+ param_repo.add("a", TensorSpec::from_expr(a));
+ auto expect = TensorSpec::from_expr(result);
+ EXPECT_FALSE(ValueType::from_spec(expect.type()).is_error());
+ EXPECT_EQUAL(EvalFixture::ref(expr, param_repo), expect);
+ EXPECT_EQUAL(EvalFixture::prod(expr, param_repo), expect);
+}
+
+//-----------------------------------------------------------------------------
+
+TEST("require that simple map_subspaces work") {
+ TEST_DO(verify("tensor(x{},y[3]):{foo:[1,2,3],bar:[4,5,6]}",
+ "map_subspaces(a,f(t)(tensor(y[2])(t{y:(y)}+t{y:(y+1)})))",
+ "tensor(x{},y[2]):{foo:[3,5],bar:[9,11]}"));
+}
+
+TEST("require that scalars can be used with map_subspaces") {
+ TEST_DO(verify("3.0",
+ "map_subspaces(a,f(n)(n+5.0))",
+ "8.0"));
+}
+
+TEST("require that outer cell type is decayed when inner type is double") {
+ TEST_DO(verify("tensor<int8>(x{}):{foo:3,bar:7}",
+ "map_subspaces(a,f(n)(n+2))",
+ "tensor<float>(x{}):{foo:5,bar:9}"));
+}
+
+TEST("require that inner cell type is used directly without decay") {
+ TEST_DO(verify("tensor(x{},y[3]):{foo:[1,2,3],bar:[4,5,6]}",
+ "map_subspaces(a,f(t)(cell_cast(t,int8)))",
+ "tensor<int8>(x{},y[3]):{foo:[1,2,3],bar:[4,5,6]}"));
+ TEST_DO(verify("tensor(y[3]):[1,2,3]",
+ "map_subspaces(a,f(t)(cell_cast(t,int8)))",
+ "tensor<int8>(y[3]):[1,2,3]"));
+}
+
+TEST("require that map_subspaces can be nested") {
+ TEST_DO(verify("tensor(x{},y[3]):{foo:[1,2,3],bar:[4,5,6]}",
+ "map_subspaces(a,f(a)(5+map_subspaces(a,f(t)(tensor(y[2])(t{y:(y)}+t{y:(y+1)})))))",
+ "tensor(x{},y[2]):{foo:[8,10],bar:[14,16]}"));
+}
+
+size_t count_nodes(const NodeTypes &types) {
+ size_t cnt = 0;
+ types.each([&](const auto &, const auto &){++cnt;});
+ return cnt;
+}
+
+void check_errors(const NodeTypes &types) {
+ for (const auto &err: types.errors()) {
+ fprintf(stderr, "%s\n", err.c_str());
+ }
+ ASSERT_EQUAL(types.errors().size(), 0u);
+}
+
+TEST("require that type resolving also include nodes from the mapping lambda function") {
+ auto fun = Function::parse("map_subspaces(a,f(a)(map_subspaces(a,f(t)(tensor(y[2])(t{y:(y)}+t{y:(y+1)})))))");
+ NodeTypes types(*fun, {ValueType::from_spec("tensor(x{},y[3])")});
+ check_errors(types);
+ auto map_subspaces = nodes::as<nodes::TensorMapSubspaces>(fun->root());
+ ASSERT_TRUE(map_subspaces != nullptr);
+ EXPECT_EQUAL(types.get_type(*map_subspaces).to_spec(), "tensor(x{},y[2])");
+ EXPECT_EQUAL(types.get_type(map_subspaces->lambda().root()).to_spec(), "tensor(y[2])");
+
+ NodeTypes copy = types.export_types(fun->root());
+ check_errors(copy);
+ EXPECT_EQUAL(count_nodes(types), count_nodes(copy));
+
+ NodeTypes map_types = copy.export_types(map_subspaces->lambda().root());
+ check_errors(map_types);
+ EXPECT_LESS(count_nodes(map_types), count_nodes(copy));
+
+ auto inner_map = nodes::as<nodes::TensorMapSubspaces>(map_subspaces->lambda().root());
+ ASSERT_TRUE(inner_map != nullptr);
+ NodeTypes inner_types = map_types.export_types(inner_map->lambda().root());
+ check_errors(inner_types);
+ EXPECT_LESS(count_nodes(inner_types), count_nodes(map_types));
+
+ // [lambda, peek, t, y, +, peek, t, y, +, 1] are the 10 nodes:
+ EXPECT_EQUAL(count_nodes(inner_types), 10u);
+}
+
+TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/tests/eval/node_types/node_types_test.cpp b/eval/src/tests/eval/node_types/node_types_test.cpp
index d58f9fda943..ad5426e2a99 100644
--- a/eval/src/tests/eval/node_types/node_types_test.cpp
+++ b/eval/src/tests/eval/node_types/node_types_test.cpp
@@ -354,6 +354,42 @@ TEST("require that tensor cell_cast resolves correct type") {
TEST_DO(verify("cell_cast(tensor<float>(x{},y[5]),int8)", "tensor<int8>(x{},y[5])"));
}
+TEST("require that tensor map_subspace resolves correct type") {
+ // double input
+ TEST_DO(verify("map_subspaces(double, f(a)(a))", "double"));
+ TEST_DO(verify("map_subspaces(double, f(a)(tensor<int8>(y[2]):[a,a]))", "tensor<int8>(y[2])"));
+
+ // sparse input
+ TEST_DO(verify("map_subspaces(tensor<float>(x{}), f(a)(a))", "tensor<float>(x{})"));
+ TEST_DO(verify("map_subspaces(tensor<int8>(x{}), f(a)(a))", "tensor<float>(x{})")); // NB: decay
+ TEST_DO(verify("map_subspaces(tensor<float>(x{}), f(a)(tensor<int8>(y[2]):[a,a]))", "tensor<int8>(x{},y[2])"));
+
+ // dense input
+ TEST_DO(verify("map_subspaces(tensor<float>(y[10]), f(a)(a))", "tensor<float>(y[10])"));
+ TEST_DO(verify("map_subspaces(tensor<int8>(y[10]), f(a)(a))", "tensor<int8>(y[10])")); // NB: no decay
+ TEST_DO(verify("map_subspaces(tensor<float>(y[10]), f(a)(reduce(a,sum)))", "double"));
+ TEST_DO(verify("map_subspaces(tensor<float>(y[10]), f(a)(cell_cast(a,int8)))", "tensor<int8>(y[10])"));
+ TEST_DO(verify("map_subspaces(tensor<int8>(y[10]), f(a)(a*tensor<int8>(z[2]):[a{y:0},a{y:1}]))", "tensor<float>(y[10],z[2])"));
+
+ // mixed input
+ TEST_DO(verify("map_subspaces(tensor<float>(x{},y[10]), f(a)(a))", "tensor<float>(x{},y[10])"));
+ TEST_DO(verify("map_subspaces(tensor<int8>(x{},y[10]), f(a)(a))", "tensor<int8>(x{},y[10])"));
+ TEST_DO(verify("map_subspaces(tensor<int8>(x{},y[10]), f(a)(map_subspaces(a, f(b)(b))))", "tensor<int8>(x{},y[10])"));
+ TEST_DO(verify("map_subspaces(tensor<int8>(x{},y[10]), f(a)(map(a, f(b)(b))))", "tensor<float>(x{},y[10])"));
+ TEST_DO(verify("map_subspaces(tensor<float>(x{},y[10]), f(y)(cell_cast(y,int8)))", "tensor<int8>(x{},y[10])"));
+ TEST_DO(verify("map_subspaces(tensor<float>(x{},y[10]), f(y)(reduce(y,sum)))", "tensor<float>(x{})"));
+ TEST_DO(verify("map_subspaces(tensor<int8>(x{},y[10]), f(y)(reduce(y,sum)))", "tensor<float>(x{})"));
+ TEST_DO(verify("map_subspaces(tensor<float>(x{},y[10]), f(y)(concat(concat(y,y,y),y,y)))", "tensor<float>(x{},y[30])"));
+ TEST_DO(verify("map_subspaces(tensor<float>(x{},y[10]), f(y)(y*tensor<float>(z[5])(z+3)))", "tensor<float>(x{},y[10],z[5])"));
+
+ // error cases
+ TEST_DO(verify("map_subspaces(error, f(a)(a))", "error"));
+ TEST_DO(verify("map_subspaces(double, f(a)(tensor(x[5])(x)+tensor(x[7])(x)))", "error"));
+ TEST_DO(verify("map_subspaces(tensor<float>(x{}), f(a)(tensor(y{}):{a:3}))", "error"));
+ TEST_DO(verify("map_subspaces(tensor<float>(y[10]), f(a)(a+tensor(y[7])(y)))", "error"));
+ TEST_DO(verify("map_subspaces(tensor<float>(x{},y[10]), f(y)(y*tensor<float>(x[5])(x+3)))", "error"));
+}
+
TEST("require that double only expressions can be detected") {
auto plain_fun = Function::parse("1+2");
auto complex_fun = Function::parse("reduce(a,sum)");
diff --git a/eval/src/tests/eval/reference_evaluation/reference_evaluation_test.cpp b/eval/src/tests/eval/reference_evaluation/reference_evaluation_test.cpp
index bcb738781ad..6df1a7fdb34 100644
--- a/eval/src/tests/eval/reference_evaluation/reference_evaluation_test.cpp
+++ b/eval/src/tests/eval/reference_evaluation/reference_evaluation_test.cpp
@@ -130,6 +130,12 @@ TEST(ReferenceEvaluationTest, map_expression_works) {
EXPECT_EQ(ref_eval("map(a,f(x)(x*2+3))", {a}), expect);
}
+TEST(ReferenceEvaluationTest, map_subspaces_expression_works) {
+ auto a = make_val("tensor(x{},y[3]):{foo:[1,2,3],bar:[4,5,6]}");
+ auto expect = make_val("tensor(x{},y[2]):{foo:[3,5],bar:[9,11]}");
+ EXPECT_EQ(ref_eval("map_subspaces(a,f(x)(tensor(y[2])(x{y:(y)}+x{y:(y+1)})))", {a}), expect);
+}
+
TEST(ReferenceEvaluationTest, join_expression_works) {
auto a = make_val("tensor(x[2]):[1,2]");
auto b = make_val("tensor(y[2]):[3,4]");
diff --git a/eval/src/tests/eval/reference_operations/reference_operations_test.cpp b/eval/src/tests/eval/reference_operations/reference_operations_test.cpp
index ee876f67f34..e125a29d75f 100644
--- a/eval/src/tests/eval/reference_operations/reference_operations_test.cpp
+++ b/eval/src/tests/eval/reference_operations/reference_operations_test.cpp
@@ -49,6 +49,12 @@ TensorSpec sparse_1d_all_two() {
.add({{"c", "qux"}}, 2.0);
}
+TensorSpec spec(const vespalib::string &expr) {
+ auto result = TensorSpec::from_expr(expr);
+ EXPECT_FALSE(ValueType::from_spec(result.type()).is_error());
+ return result;
+}
+
//-----------------------------------------------------------------------------
TEST(ReferenceConcatTest, concat_numbers) {
@@ -234,6 +240,62 @@ TEST(ReferenceMapTest, map_mixed_tensor) {
//-----------------------------------------------------------------------------
+TEST(ReferenceMapSubspacesTest, map_vectors) {
+ auto input = spec("tensor(x{},y[3]):{foo:[1,2,3],bar:[4,5,6]}");
+ auto fun = [&](const TensorSpec &space) {
+ EXPECT_EQ(space.type(), "tensor(y[3])");
+ size_t i = 0;
+ double a = 0.0;
+ double b = 0.0;
+ for (const auto &[addr, value]: space.cells()) {
+ if (i < 2) {
+ a += value;
+ }
+ if (i > 0) {
+ b += value;
+ }
+ ++i;
+ }
+ TensorSpec result("tensor(y[2])");
+ result.add({{"y", 0}}, a);
+ result.add({{"y", 1}}, b);
+ return result;
+ };
+ auto output = ReferenceOperations::map_subspaces(input, fun);
+ auto expect = spec("tensor(x{},y[2]):{foo:[3,5],bar:[9,11]}");
+ EXPECT_EQ(output, expect);
+}
+
+TEST(ReferenceMapSubspacesTest, map_numbers_with_external_decay) {
+ auto input = spec("tensor<bfloat16>(x{}):{foo:3,bar:5}");
+ auto fun = [&](const TensorSpec &space) {
+ EXPECT_EQ(space.type(), "double");
+ TensorSpec result("double");
+ result.add({}, space.cells().begin()->second + 4.0);
+ return result;
+ };
+ auto output = ReferenceOperations::map_subspaces(input, fun);
+ auto expect = spec("tensor<float>(x{}):{foo:7,bar:9}");
+ EXPECT_EQ(output, expect);
+}
+
+TEST(ReferenceMapSubspacesTest, cast_cells_without_internal_decay) {
+ auto input = spec("tensor<float>(x{},y[3]):{foo:[1,2,3],bar:[4,5,6]}");
+ auto fun = [&](const TensorSpec &space) {
+ EXPECT_EQ(space.type(), "tensor<float>(y[3])");
+ TensorSpec result("tensor<bfloat16>(y[3])");
+ for (const auto &[addr, value]: space.cells()) {
+ result.add(addr, value);
+ }
+ return result;
+ };
+ auto output = ReferenceOperations::map_subspaces(input, fun);
+ auto expect = spec("tensor<bfloat16>(x{},y[3]):{foo:[1,2,3],bar:[4,5,6]}");
+ EXPECT_EQ(output, expect);
+}
+
+//-----------------------------------------------------------------------------
+
TEST(ReferenceMergeTest, simple_mixed_merge) {
auto a = mixed_5d_input(false);
auto b = TensorSpec("tensor(a[3],b[1],c{},d[5],e{})")
diff --git a/eval/src/tests/eval/value_type/value_type_test.cpp b/eval/src/tests/eval/value_type/value_type_test.cpp
index 84b9d685ee8..f472a5fb6a6 100644
--- a/eval/src/tests/eval/value_type/value_type_test.cpp
+++ b/eval/src/tests/eval/value_type/value_type_test.cpp
@@ -351,6 +351,36 @@ TEST("require that mapped dimensions can be obtained") {
TEST_DO(my_check(type("tensor(a[1],b[1],x{},y[10],z[1])").mapped_dimensions()));
}
+TEST("require that mapped dimensions can be stripped") {
+ EXPECT_EQUAL(type("error").strip_mapped_dimensions(), type("error"));
+ EXPECT_EQUAL(type("double").strip_mapped_dimensions(), type("double"));
+ EXPECT_EQUAL(type("tensor<float>(x{})").strip_mapped_dimensions(), type("double"));
+ EXPECT_EQUAL(type("tensor<float>(x[10])").strip_mapped_dimensions(), type("tensor<float>(x[10])"));
+ EXPECT_EQUAL(type("tensor<float>(a[1],b{},c[2],d{},e[3],f{})").strip_mapped_dimensions(), type("tensor<float>(a[1],c[2],e[3])"));
+}
+
+TEST("require that indexed dimensions can be stripped") {
+ EXPECT_EQUAL(type("error").strip_indexed_dimensions(), type("error"));
+ EXPECT_EQUAL(type("double").strip_indexed_dimensions(), type("double"));
+ EXPECT_EQUAL(type("tensor<float>(x{})").strip_indexed_dimensions(), type("tensor<float>(x{})"));
+ EXPECT_EQUAL(type("tensor<float>(x[10])").strip_indexed_dimensions(), type("double"));
+ EXPECT_EQUAL(type("tensor<float>(a[1],b{},c[2],d{},e[3],f{})").strip_indexed_dimensions(), type("tensor<float>(b{},d{},f{})"));
+}
+
+TEST("require that value types can be wrapped inside each other") {
+ EXPECT_EQUAL(type("error").wrap(type("error")), type("error"));
+ EXPECT_EQUAL(type("double").wrap(type("error")), type("error"));
+ EXPECT_EQUAL(type("error").wrap(type("double")), type("error"));
+ EXPECT_EQUAL(type("double").wrap(type("double")), type("double"));
+ EXPECT_EQUAL(type("tensor<int8>(x{})").wrap(type("tensor<int8>(y[10])")), type("tensor<int8>(x{},y[10])"));
+ EXPECT_EQUAL(type("tensor<int8>(a{},c{})").wrap(type("tensor<int8>(b[10],d[5])")), type("tensor<int8>(a{},b[10],c{},d[5])"));
+ EXPECT_EQUAL(type("tensor<int8>(x{})").wrap(type("tensor<int8>(x[10])")), type("error")); // dimension name conflict
+ EXPECT_EQUAL(type("tensor<int8>(x{},z[2])").wrap(type("tensor<int8>(y[10])")), type("error")); // outer cannot have indexed dimensions
+ EXPECT_EQUAL(type("tensor<int8>(x{})").wrap(type("tensor<int8>(y[10],z{})")), type("error")); // inner cannot have mapped dimensions
+ EXPECT_EQUAL(type("double").wrap(type("tensor<int8>(y[10])")), type("tensor<int8>(y[10])")); // NB: no decay
+ EXPECT_EQUAL(type("tensor<int8>(x{})").wrap(type("double")), type("tensor<float>(x{})")); // NB: decay
+}
+
TEST("require that dimension index can be obtained") {
EXPECT_EQUAL(type("error").dimension_index("x"), ValueType::Dimension::npos);
EXPECT_EQUAL(type("double").dimension_index("x"), ValueType::Dimension::npos);
diff --git a/eval/src/vespa/eval/eval/cell_type.h b/eval/src/vespa/eval/eval/cell_type.h
index b1fa29a75a5..c15a5b68dba 100644
--- a/eval/src/vespa/eval/eval/cell_type.h
+++ b/eval/src/vespa/eval/eval/cell_type.h
@@ -129,6 +129,9 @@ struct CellMeta {
// convenience functions to be used for specific operations
constexpr CellMeta map() const { return decay(); }
+ constexpr CellMeta wrap(CellMeta inner) const {
+ return (inner.is_scalar) ? decay() : inner;
+ }
constexpr CellMeta reduce(bool output_is_scalar) const {
return normalize(cell_type, output_is_scalar).decay();
}
diff --git a/eval/src/vespa/eval/eval/function.cpp b/eval/src/vespa/eval/eval/function.cpp
index edcd241b6bf..1c4dcc3b5db 100644
--- a/eval/src/vespa/eval/eval/function.cpp
+++ b/eval/src/vespa/eval/eval/function.cpp
@@ -573,6 +573,13 @@ void parse_tensor_map(ParseContext &ctx) {
ctx.push_expression(std::make_unique<nodes::TensorMap>(std::move(child), std::move(lambda)));
}
+void parse_tensor_map_subspaces(ParseContext &ctx) {
+ Node_UP child = get_expression(ctx);
+ ctx.eat(',');
+ auto lambda = parse_lambda(ctx, 1);
+ ctx.push_expression(std::make_unique<nodes::TensorMapSubspaces>(std::move(child), std::move(lambda)));
+}
+
void parse_tensor_join(ParseContext &ctx) {
Node_UP lhs = get_expression(ctx);
ctx.eat(',');
@@ -856,6 +863,8 @@ bool maybe_parse_call(ParseContext &ctx, const vespalib::string &name) {
parse_call(ctx, std::move(call));
} else if (name == "map") {
parse_tensor_map(ctx);
+ } else if (name == "map_subspaces") {
+ parse_tensor_map_subspaces(ctx);
} else if (name == "join") {
parse_tensor_join(ctx);
} else if (name == "merge") {
diff --git a/eval/src/vespa/eval/eval/key_gen.cpp b/eval/src/vespa/eval/eval/key_gen.cpp
index 2df20ac0d63..6d45aeafc26 100644
--- a/eval/src/vespa/eval/eval/key_gen.cpp
+++ b/eval/src/vespa/eval/eval/key_gen.cpp
@@ -39,64 +39,65 @@ struct KeyGen : public NodeVisitor, public NodeTraverser {
add_double(node.get_entry(i).get_const_double_value());
}
}
- void visit(const Neg &) override { add_byte( 5); }
- void visit(const Not &) override { add_byte( 6); }
- void visit(const If &node) override { add_byte( 7); add_double(node.p_true()); }
- void visit(const Error &) override { add_byte( 9); }
- void visit(const TensorMap &) override { add_byte(10); } // lambda should be part of key
- void visit(const TensorJoin &) override { add_byte(11); } // lambda should be part of key
- void visit(const TensorMerge &) override { add_byte(12); } // lambda should be part of key
- void visit(const TensorReduce &) override { add_byte(13); } // aggr/dimensions should be part of key
- void visit(const TensorRename &) override { add_byte(14); } // dimensions should be part of key
- void visit(const TensorConcat &) override { add_byte(15); } // dimension should be part of key
- void visit(const TensorCellCast &) override { add_byte(16); } // cell type should be part of key
- void visit(const TensorCreate &) override { add_byte(17); } // type/addr should be part of key
- void visit(const TensorLambda &) override { add_byte(18); } // type/lambda should be part of key
- void visit(const TensorPeek &) override { add_byte(19); } // addr should be part of key
- void visit(const Add &) override { add_byte(20); }
- void visit(const Sub &) override { add_byte(21); }
- void visit(const Mul &) override { add_byte(22); }
- void visit(const Div &) override { add_byte(23); }
- void visit(const Mod &) override { add_byte(24); }
- void visit(const Pow &) override { add_byte(25); }
- void visit(const Equal &) override { add_byte(26); }
- void visit(const NotEqual &) override { add_byte(27); }
- void visit(const Approx &) override { add_byte(28); }
- void visit(const Less &) override { add_byte(29); }
- void visit(const LessEqual &) override { add_byte(30); }
- void visit(const Greater &) override { add_byte(31); }
- void visit(const GreaterEqual &) override { add_byte(32); }
- void visit(const And &) override { add_byte(34); }
- void visit(const Or &) override { add_byte(35); }
- void visit(const Cos &) override { add_byte(36); }
- void visit(const Sin &) override { add_byte(37); }
- void visit(const Tan &) override { add_byte(38); }
- void visit(const Cosh &) override { add_byte(39); }
- void visit(const Sinh &) override { add_byte(40); }
- void visit(const Tanh &) override { add_byte(41); }
- void visit(const Acos &) override { add_byte(42); }
- void visit(const Asin &) override { add_byte(43); }
- void visit(const Atan &) override { add_byte(44); }
- void visit(const Exp &) override { add_byte(45); }
- void visit(const Log10 &) override { add_byte(46); }
- void visit(const Log &) override { add_byte(47); }
- void visit(const Sqrt &) override { add_byte(48); }
- void visit(const Ceil &) override { add_byte(49); }
- void visit(const Fabs &) override { add_byte(50); }
- void visit(const Floor &) override { add_byte(51); }
- void visit(const Atan2 &) override { add_byte(52); }
- void visit(const Ldexp &) override { add_byte(53); }
- void visit(const Pow2 &) override { add_byte(54); }
- void visit(const Fmod &) override { add_byte(55); }
- void visit(const Min &) override { add_byte(56); }
- void visit(const Max &) override { add_byte(57); }
- void visit(const IsNan &) override { add_byte(58); }
- void visit(const Relu &) override { add_byte(59); }
- void visit(const Sigmoid &) override { add_byte(60); }
- void visit(const Elu &) override { add_byte(61); }
- void visit(const Erf &) override { add_byte(62); }
- void visit(const Bit &) override { add_byte(63); }
- void visit(const Hamming &) override { add_byte(64); }
+ void visit(const Neg &) override { add_byte( 5); }
+ void visit(const Not &) override { add_byte( 6); }
+ void visit(const If &node) override { add_byte( 7); add_double(node.p_true()); }
+ void visit(const Error &) override { add_byte( 8); }
+ void visit(const TensorMap &) override { add_byte( 9); } // lambda should be part of key
+ void visit(const TensorMapSubspaces &) override { add_byte(10); } // lambda should be part of key
+ void visit(const TensorJoin &) override { add_byte(11); } // lambda should be part of key
+ void visit(const TensorMerge &) override { add_byte(12); } // lambda should be part of key
+ void visit(const TensorReduce &) override { add_byte(13); } // aggr/dimensions should be part of key
+ void visit(const TensorRename &) override { add_byte(14); } // dimensions should be part of key
+ void visit(const TensorConcat &) override { add_byte(15); } // dimension should be part of key
+ void visit(const TensorCellCast &) override { add_byte(16); } // cell type should be part of key
+ void visit(const TensorCreate &) override { add_byte(17); } // type/addr should be part of key
+ void visit(const TensorLambda &) override { add_byte(18); } // type/lambda should be part of key
+ void visit(const TensorPeek &) override { add_byte(19); } // addr should be part of key
+ void visit(const Add &) override { add_byte(20); }
+ void visit(const Sub &) override { add_byte(21); }
+ void visit(const Mul &) override { add_byte(22); }
+ void visit(const Div &) override { add_byte(23); }
+ void visit(const Mod &) override { add_byte(24); }
+ void visit(const Pow &) override { add_byte(25); }
+ void visit(const Equal &) override { add_byte(26); }
+ void visit(const NotEqual &) override { add_byte(27); }
+ void visit(const Approx &) override { add_byte(28); }
+ void visit(const Less &) override { add_byte(29); }
+ void visit(const LessEqual &) override { add_byte(30); }
+ void visit(const Greater &) override { add_byte(31); }
+ void visit(const GreaterEqual &) override { add_byte(32); }
+ void visit(const And &) override { add_byte(34); }
+ void visit(const Or &) override { add_byte(35); }
+ void visit(const Cos &) override { add_byte(36); }
+ void visit(const Sin &) override { add_byte(37); }
+ void visit(const Tan &) override { add_byte(38); }
+ void visit(const Cosh &) override { add_byte(39); }
+ void visit(const Sinh &) override { add_byte(40); }
+ void visit(const Tanh &) override { add_byte(41); }
+ void visit(const Acos &) override { add_byte(42); }
+ void visit(const Asin &) override { add_byte(43); }
+ void visit(const Atan &) override { add_byte(44); }
+ void visit(const Exp &) override { add_byte(45); }
+ void visit(const Log10 &) override { add_byte(46); }
+ void visit(const Log &) override { add_byte(47); }
+ void visit(const Sqrt &) override { add_byte(48); }
+ void visit(const Ceil &) override { add_byte(49); }
+ void visit(const Fabs &) override { add_byte(50); }
+ void visit(const Floor &) override { add_byte(51); }
+ void visit(const Atan2 &) override { add_byte(52); }
+ void visit(const Ldexp &) override { add_byte(53); }
+ void visit(const Pow2 &) override { add_byte(54); }
+ void visit(const Fmod &) override { add_byte(55); }
+ void visit(const Min &) override { add_byte(56); }
+ void visit(const Max &) override { add_byte(57); }
+ void visit(const IsNan &) override { add_byte(58); }
+ void visit(const Relu &) override { add_byte(59); }
+ void visit(const Sigmoid &) override { add_byte(60); }
+ void visit(const Elu &) override { add_byte(61); }
+ void visit(const Erf &) override { add_byte(62); }
+ void visit(const Bit &) override { add_byte(63); }
+ void visit(const Hamming &) override { add_byte(64); }
// traverse
bool open(const Node &node) override { node.accept(*this); return true; }
diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
index 5266aa64b8c..ca95d822be7 100644
--- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
+++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
@@ -456,6 +456,9 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser {
void visit(const TensorMap &node) override {
make_error(node.num_children());
}
+ void visit(const TensorMapSubspaces &node) override {
+ make_error(node.num_children());
+ }
void visit(const TensorJoin &node) override {
make_error(node.num_children());
}
diff --git a/eval/src/vespa/eval/eval/make_tensor_function.cpp b/eval/src/vespa/eval/eval/make_tensor_function.cpp
index fe9c9704f6c..0b671a7725e 100644
--- a/eval/src/vespa/eval/eval/make_tensor_function.cpp
+++ b/eval/src/vespa/eval/eval/make_tensor_function.cpp
@@ -51,6 +51,12 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser {
stack.back() = tensor_function::map(a, function, stash);
}
+ void make_map_subspaces(const TensorMapSubspaces &node) {
+ assert(stack.size() >= 1);
+ const auto &a = stack.back().get();
+ stack.back() = tensor_function::map_subspaces(a, node.lambda(), types.export_types(node.lambda().root()), stash);
+ }
+
void make_join(const Node &, operation::op2_t function) {
assert(stack.size() >= 2);
const auto &b = stack.back().get();
@@ -198,6 +204,9 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser {
make_map(node, token.get()->get().get_function<1>());
}
}
+ void visit(const TensorMapSubspaces &node) override {
+ make_map_subspaces(node);
+ }
void visit(const TensorJoin &node) override {
if (auto op2 = operation::lookup_op2(node.lambda())) {
make_join(node, op2.value());
diff --git a/eval/src/vespa/eval/eval/node_tools.cpp b/eval/src/vespa/eval/eval/node_tools.cpp
index 482111aba8d..20712b833ee 100644
--- a/eval/src/vespa/eval/eval/node_tools.cpp
+++ b/eval/src/vespa/eval/eval/node_tools.cpp
@@ -125,64 +125,65 @@ struct CopyNode : NodeTraverser, NodeVisitor {
}
// tensor nodes
- void visit(const TensorMap &node) override { not_implemented(node); }
- void visit(const TensorJoin &node) override { not_implemented(node); }
- void visit(const TensorMerge &node) override { not_implemented(node); }
- void visit(const TensorReduce &node) override { not_implemented(node); }
- void visit(const TensorRename &node) override { not_implemented(node); }
- void visit(const TensorConcat &node) override { not_implemented(node); }
- void visit(const TensorCellCast &node) override { not_implemented(node); }
- void visit(const TensorCreate &node) override { not_implemented(node); }
- void visit(const TensorLambda &node) override { not_implemented(node); }
- void visit(const TensorPeek &node) override { not_implemented(node); }
+ void visit(const TensorMap &node) override { not_implemented(node); }
+ void visit(const TensorMapSubspaces &node) override { not_implemented(node); }
+ void visit(const TensorJoin &node) override { not_implemented(node); }
+ void visit(const TensorMerge &node) override { not_implemented(node); }
+ void visit(const TensorReduce &node) override { not_implemented(node); }
+ void visit(const TensorRename &node) override { not_implemented(node); }
+ void visit(const TensorConcat &node) override { not_implemented(node); }
+ void visit(const TensorCellCast &node) override { not_implemented(node); }
+ void visit(const TensorCreate &node) override { not_implemented(node); }
+ void visit(const TensorLambda &node) override { not_implemented(node); }
+ void visit(const TensorPeek &node) override { not_implemented(node); }
// operator nodes
- void visit(const Add &node) override { copy_operator(node); }
- void visit(const Sub &node) override { copy_operator(node); }
- void visit(const Mul &node) override { copy_operator(node); }
- void visit(const Div &node) override { copy_operator(node); }
- void visit(const Mod &node) override { copy_operator(node); }
- void visit(const Pow &node) override { copy_operator(node); }
- void visit(const Equal &node) override { copy_operator(node); }
- void visit(const NotEqual &node) override { copy_operator(node); }
- void visit(const Approx &node) override { copy_operator(node); }
- void visit(const Less &node) override { copy_operator(node); }
- void visit(const LessEqual &node) override { copy_operator(node); }
- void visit(const Greater &node) override { copy_operator(node); }
- void visit(const GreaterEqual &node) override { copy_operator(node); }
- void visit(const And &node) override { copy_operator(node); }
- void visit(const Or &node) override { copy_operator(node); }
+ void visit(const Add &node) override { copy_operator(node); }
+ void visit(const Sub &node) override { copy_operator(node); }
+ void visit(const Mul &node) override { copy_operator(node); }
+ void visit(const Div &node) override { copy_operator(node); }
+ void visit(const Mod &node) override { copy_operator(node); }
+ void visit(const Pow &node) override { copy_operator(node); }
+ void visit(const Equal &node) override { copy_operator(node); }
+ void visit(const NotEqual &node) override { copy_operator(node); }
+ void visit(const Approx &node) override { copy_operator(node); }
+ void visit(const Less &node) override { copy_operator(node); }
+ void visit(const LessEqual &node) override { copy_operator(node); }
+ void visit(const Greater &node) override { copy_operator(node); }
+ void visit(const GreaterEqual &node) override { copy_operator(node); }
+ void visit(const And &node) override { copy_operator(node); }
+ void visit(const Or &node) override { copy_operator(node); }
// call nodes
- void visit(const Cos &node) override { copy_call(node); }
- void visit(const Sin &node) override { copy_call(node); }
- void visit(const Tan &node) override { copy_call(node); }
- void visit(const Cosh &node) override { copy_call(node); }
- void visit(const Sinh &node) override { copy_call(node); }
- void visit(const Tanh &node) override { copy_call(node); }
- void visit(const Acos &node) override { copy_call(node); }
- void visit(const Asin &node) override { copy_call(node); }
- void visit(const Atan &node) override { copy_call(node); }
- void visit(const Exp &node) override { copy_call(node); }
- void visit(const Log10 &node) override { copy_call(node); }
- void visit(const Log &node) override { copy_call(node); }
- void visit(const Sqrt &node) override { copy_call(node); }
- void visit(const Ceil &node) override { copy_call(node); }
- void visit(const Fabs &node) override { copy_call(node); }
- void visit(const Floor &node) override { copy_call(node); }
- void visit(const Atan2 &node) override { copy_call(node); }
- void visit(const Ldexp &node) override { copy_call(node); }
- void visit(const Pow2 &node) override { copy_call(node); }
- void visit(const Fmod &node) override { copy_call(node); }
- void visit(const Min &node) override { copy_call(node); }
- void visit(const Max &node) override { copy_call(node); }
- void visit(const IsNan &node) override { copy_call(node); }
- void visit(const Relu &node) override { copy_call(node); }
- void visit(const Sigmoid &node) override { copy_call(node); }
- void visit(const Elu &node) override { copy_call(node); }
- void visit(const Erf &node) override { copy_call(node); }
- void visit(const Bit &node) override { copy_call(node); }
- void visit(const Hamming &node) override { copy_call(node); }
+ void visit(const Cos &node) override { copy_call(node); }
+ void visit(const Sin &node) override { copy_call(node); }
+ void visit(const Tan &node) override { copy_call(node); }
+ void visit(const Cosh &node) override { copy_call(node); }
+ void visit(const Sinh &node) override { copy_call(node); }
+ void visit(const Tanh &node) override { copy_call(node); }
+ void visit(const Acos &node) override { copy_call(node); }
+ void visit(const Asin &node) override { copy_call(node); }
+ void visit(const Atan &node) override { copy_call(node); }
+ void visit(const Exp &node) override { copy_call(node); }
+ void visit(const Log10 &node) override { copy_call(node); }
+ void visit(const Log &node) override { copy_call(node); }
+ void visit(const Sqrt &node) override { copy_call(node); }
+ void visit(const Ceil &node) override { copy_call(node); }
+ void visit(const Fabs &node) override { copy_call(node); }
+ void visit(const Floor &node) override { copy_call(node); }
+ void visit(const Atan2 &node) override { copy_call(node); }
+ void visit(const Ldexp &node) override { copy_call(node); }
+ void visit(const Pow2 &node) override { copy_call(node); }
+ void visit(const Fmod &node) override { copy_call(node); }
+ void visit(const Min &node) override { copy_call(node); }
+ void visit(const Max &node) override { copy_call(node); }
+ void visit(const IsNan &node) override { copy_call(node); }
+ void visit(const Relu &node) override { copy_call(node); }
+ void visit(const Sigmoid &node) override { copy_call(node); }
+ void visit(const Elu &node) override { copy_call(node); }
+ void visit(const Erf &node) override { copy_call(node); }
+ void visit(const Bit &node) override { copy_call(node); }
+ void visit(const Hamming &node) override { copy_call(node); }
// traverse nodes
bool open(const Node &) override { return !error; }
diff --git a/eval/src/vespa/eval/eval/node_types.cpp b/eval/src/vespa/eval/eval/node_types.cpp
index c234631984f..767d0f8b28a 100644
--- a/eval/src/vespa/eval/eval/node_types.cpp
+++ b/eval/src/vespa/eval/eval/node_types.cpp
@@ -139,6 +139,29 @@ struct TypeResolver : public NodeVisitor, public NodeTraverser {
bind(ValueType::error_type(), node, false);
}
void visit(const TensorMap &node) override { resolve_op1(node); }
+ void visit(const TensorMapSubspaces &node) override {
+ const ValueType &in_type = type(node.child());
+ auto outer_type = in_type.strip_indexed_dimensions();
+ auto inner_type = in_type.strip_mapped_dimensions();
+ std::vector<ValueType> arg_type({inner_type});
+ NodeTypes lambda_types(node.lambda(), arg_type);
+ const ValueType &lambda_res = lambda_types.get_type(node.lambda().root());
+ if (lambda_res.is_error()) {
+ import_errors(lambda_types);
+ return fail(node, "lambda function has type errors", false);
+ }
+ if (lambda_res.count_mapped_dimensions() > 0) {
+ return fail(node, fmt("lambda function result contains mapped dimensions: %s",
+ lambda_res.to_spec().c_str()), false);
+ }
+ auto res_type = outer_type.wrap(lambda_res);
+ if (res_type.is_error()) {
+ return fail(node, fmt("lambda result contains dimensions that conflict with input type: %s <-> %s",
+ lambda_res.to_spec().c_str(), in_type.to_spec().c_str()), false);
+ }
+ import_types(lambda_types);
+ bind(res_type, node);
+ }
void visit(const TensorJoin &node) override { resolve_op2(node); }
void visit(const TensorMerge &node) override {
bind(ValueType::merge(type(node.get_child(0)),
@@ -316,6 +339,9 @@ struct TypeExporter : public NodeTraverser {
if (auto lambda = as<TensorLambda>(node)) {
lambda->lambda().root().traverse(*this);
}
+ if (auto map_subspaces = as<TensorMapSubspaces>(node)) {
+ map_subspaces->lambda().root().traverse(*this);
+ }
return true;
}
void close(const Node &node) override {
diff --git a/eval/src/vespa/eval/eval/node_visitor.h b/eval/src/vespa/eval/eval/node_visitor.h
index dcd1486824a..c57253d138c 100644
--- a/eval/src/vespa/eval/eval/node_visitor.h
+++ b/eval/src/vespa/eval/eval/node_visitor.h
@@ -18,74 +18,75 @@ namespace vespalib::eval {
struct NodeVisitor {
// basic nodes
- virtual void visit(const nodes::Number &) = 0;
- virtual void visit(const nodes::Symbol &) = 0;
- virtual void visit(const nodes::String &) = 0;
- virtual void visit(const nodes::In &) = 0;
- virtual void visit(const nodes::Neg &) = 0;
- virtual void visit(const nodes::Not &) = 0;
- virtual void visit(const nodes::If &) = 0;
- virtual void visit(const nodes::Error &) = 0;
+ virtual void visit(const nodes::Number &) = 0;
+ virtual void visit(const nodes::Symbol &) = 0;
+ virtual void visit(const nodes::String &) = 0;
+ virtual void visit(const nodes::In &) = 0;
+ virtual void visit(const nodes::Neg &) = 0;
+ virtual void visit(const nodes::Not &) = 0;
+ virtual void visit(const nodes::If &) = 0;
+ virtual void visit(const nodes::Error &) = 0;
// tensor nodes
- virtual void visit(const nodes::TensorMap &) = 0;
- virtual void visit(const nodes::TensorJoin &) = 0;
- virtual void visit(const nodes::TensorMerge &) = 0;
- virtual void visit(const nodes::TensorReduce &) = 0;
- virtual void visit(const nodes::TensorRename &) = 0;
- virtual void visit(const nodes::TensorConcat &) = 0;
- virtual void visit(const nodes::TensorCellCast &) = 0;
- virtual void visit(const nodes::TensorCreate &) = 0;
- virtual void visit(const nodes::TensorLambda &) = 0;
- virtual void visit(const nodes::TensorPeek &) = 0;
+ virtual void visit(const nodes::TensorMap &) = 0;
+ virtual void visit(const nodes::TensorMapSubspaces &) = 0;
+ virtual void visit(const nodes::TensorJoin &) = 0;
+ virtual void visit(const nodes::TensorMerge &) = 0;
+ virtual void visit(const nodes::TensorReduce &) = 0;
+ virtual void visit(const nodes::TensorRename &) = 0;
+ virtual void visit(const nodes::TensorConcat &) = 0;
+ virtual void visit(const nodes::TensorCellCast &) = 0;
+ virtual void visit(const nodes::TensorCreate &) = 0;
+ virtual void visit(const nodes::TensorLambda &) = 0;
+ virtual void visit(const nodes::TensorPeek &) = 0;
// operator nodes
- virtual void visit(const nodes::Add &) = 0;
- virtual void visit(const nodes::Sub &) = 0;
- virtual void visit(const nodes::Mul &) = 0;
- virtual void visit(const nodes::Div &) = 0;
- virtual void visit(const nodes::Mod &) = 0;
- virtual void visit(const nodes::Pow &) = 0;
- virtual void visit(const nodes::Equal &) = 0;
- virtual void visit(const nodes::NotEqual &) = 0;
- virtual void visit(const nodes::Approx &) = 0;
- virtual void visit(const nodes::Less &) = 0;
- virtual void visit(const nodes::LessEqual &) = 0;
- virtual void visit(const nodes::Greater &) = 0;
- virtual void visit(const nodes::GreaterEqual &) = 0;
- virtual void visit(const nodes::And &) = 0;
- virtual void visit(const nodes::Or &) = 0;
+ virtual void visit(const nodes::Add &) = 0;
+ virtual void visit(const nodes::Sub &) = 0;
+ virtual void visit(const nodes::Mul &) = 0;
+ virtual void visit(const nodes::Div &) = 0;
+ virtual void visit(const nodes::Mod &) = 0;
+ virtual void visit(const nodes::Pow &) = 0;
+ virtual void visit(const nodes::Equal &) = 0;
+ virtual void visit(const nodes::NotEqual &) = 0;
+ virtual void visit(const nodes::Approx &) = 0;
+ virtual void visit(const nodes::Less &) = 0;
+ virtual void visit(const nodes::LessEqual &) = 0;
+ virtual void visit(const nodes::Greater &) = 0;
+ virtual void visit(const nodes::GreaterEqual &) = 0;
+ virtual void visit(const nodes::And &) = 0;
+ virtual void visit(const nodes::Or &) = 0;
// call nodes
- virtual void visit(const nodes::Cos &) = 0;
- virtual void visit(const nodes::Sin &) = 0;
- virtual void visit(const nodes::Tan &) = 0;
- virtual void visit(const nodes::Cosh &) = 0;
- virtual void visit(const nodes::Sinh &) = 0;
- virtual void visit(const nodes::Tanh &) = 0;
- virtual void visit(const nodes::Acos &) = 0;
- virtual void visit(const nodes::Asin &) = 0;
- virtual void visit(const nodes::Atan &) = 0;
- virtual void visit(const nodes::Exp &) = 0;
- virtual void visit(const nodes::Log10 &) = 0;
- virtual void visit(const nodes::Log &) = 0;
- virtual void visit(const nodes::Sqrt &) = 0;
- virtual void visit(const nodes::Ceil &) = 0;
- virtual void visit(const nodes::Fabs &) = 0;
- virtual void visit(const nodes::Floor &) = 0;
- virtual void visit(const nodes::Atan2 &) = 0;
- virtual void visit(const nodes::Ldexp &) = 0;
- virtual void visit(const nodes::Pow2 &) = 0;
- virtual void visit(const nodes::Fmod &) = 0;
- virtual void visit(const nodes::Min &) = 0;
- virtual void visit(const nodes::Max &) = 0;
- virtual void visit(const nodes::IsNan &) = 0;
- virtual void visit(const nodes::Relu &) = 0;
- virtual void visit(const nodes::Sigmoid &) = 0;
- virtual void visit(const nodes::Elu &) = 0;
- virtual void visit(const nodes::Erf &) = 0;
- virtual void visit(const nodes::Bit &) = 0;
- virtual void visit(const nodes::Hamming &) = 0;
+ virtual void visit(const nodes::Cos &) = 0;
+ virtual void visit(const nodes::Sin &) = 0;
+ virtual void visit(const nodes::Tan &) = 0;
+ virtual void visit(const nodes::Cosh &) = 0;
+ virtual void visit(const nodes::Sinh &) = 0;
+ virtual void visit(const nodes::Tanh &) = 0;
+ virtual void visit(const nodes::Acos &) = 0;
+ virtual void visit(const nodes::Asin &) = 0;
+ virtual void visit(const nodes::Atan &) = 0;
+ virtual void visit(const nodes::Exp &) = 0;
+ virtual void visit(const nodes::Log10 &) = 0;
+ virtual void visit(const nodes::Log &) = 0;
+ virtual void visit(const nodes::Sqrt &) = 0;
+ virtual void visit(const nodes::Ceil &) = 0;
+ virtual void visit(const nodes::Fabs &) = 0;
+ virtual void visit(const nodes::Floor &) = 0;
+ virtual void visit(const nodes::Atan2 &) = 0;
+ virtual void visit(const nodes::Ldexp &) = 0;
+ virtual void visit(const nodes::Pow2 &) = 0;
+ virtual void visit(const nodes::Fmod &) = 0;
+ virtual void visit(const nodes::Min &) = 0;
+ virtual void visit(const nodes::Max &) = 0;
+ virtual void visit(const nodes::IsNan &) = 0;
+ virtual void visit(const nodes::Relu &) = 0;
+ virtual void visit(const nodes::Sigmoid &) = 0;
+ virtual void visit(const nodes::Elu &) = 0;
+ virtual void visit(const nodes::Erf &) = 0;
+ virtual void visit(const nodes::Bit &) = 0;
+ virtual void visit(const nodes::Hamming &) = 0;
virtual ~NodeVisitor() {}
};
@@ -95,68 +96,69 @@ struct NodeVisitor {
* of all types not specifically handled.
**/
struct EmptyNodeVisitor : NodeVisitor {
- void visit(const nodes::Number &) override {}
- void visit(const nodes::Symbol &) override {}
- void visit(const nodes::String &) override {}
- void visit(const nodes::In &) override {}
- void visit(const nodes::Neg &) override {}
- void visit(const nodes::Not &) override {}
- void visit(const nodes::If &) override {}
- void visit(const nodes::Error &) override {}
- void visit(const nodes::TensorMap &) override {}
- void visit(const nodes::TensorJoin &) override {}
- void visit(const nodes::TensorMerge &) override {}
- void visit(const nodes::TensorReduce &) override {}
- void visit(const nodes::TensorRename &) override {}
- void visit(const nodes::TensorConcat &) override {}
- void visit(const nodes::TensorCellCast &) override {}
- void visit(const nodes::TensorCreate &) override {}
- void visit(const nodes::TensorLambda &) override {}
- void visit(const nodes::TensorPeek &) override {}
- void visit(const nodes::Add &) override {}
- void visit(const nodes::Sub &) override {}
- void visit(const nodes::Mul &) override {}
- void visit(const nodes::Div &) override {}
- void visit(const nodes::Mod &) override {}
- void visit(const nodes::Pow &) override {}
- void visit(const nodes::Equal &) override {}
- void visit(const nodes::NotEqual &) override {}
- void visit(const nodes::Approx &) override {}
- void visit(const nodes::Less &) override {}
- void visit(const nodes::LessEqual &) override {}
- void visit(const nodes::Greater &) override {}
- void visit(const nodes::GreaterEqual &) override {}
- void visit(const nodes::And &) override {}
- void visit(const nodes::Or &) override {}
- void visit(const nodes::Cos &) override {}
- void visit(const nodes::Sin &) override {}
- void visit(const nodes::Tan &) override {}
- void visit(const nodes::Cosh &) override {}
- void visit(const nodes::Sinh &) override {}
- void visit(const nodes::Tanh &) override {}
- void visit(const nodes::Acos &) override {}
- void visit(const nodes::Asin &) override {}
- void visit(const nodes::Atan &) override {}
- void visit(const nodes::Exp &) override {}
- void visit(const nodes::Log10 &) override {}
- void visit(const nodes::Log &) override {}
- void visit(const nodes::Sqrt &) override {}
- void visit(const nodes::Ceil &) override {}
- void visit(const nodes::Fabs &) override {}
- void visit(const nodes::Floor &) override {}
- void visit(const nodes::Atan2 &) override {}
- void visit(const nodes::Ldexp &) override {}
- void visit(const nodes::Pow2 &) override {}
- void visit(const nodes::Fmod &) override {}
- void visit(const nodes::Min &) override {}
- void visit(const nodes::Max &) override {}
- void visit(const nodes::IsNan &) override {}
- void visit(const nodes::Relu &) override {}
- void visit(const nodes::Sigmoid &) override {}
- void visit(const nodes::Elu &) override {}
- void visit(const nodes::Erf &) override {}
- void visit(const nodes::Bit &) override {}
- void visit(const nodes::Hamming &) override {}
+ void visit(const nodes::Number &) override {}
+ void visit(const nodes::Symbol &) override {}
+ void visit(const nodes::String &) override {}
+ void visit(const nodes::In &) override {}
+ void visit(const nodes::Neg &) override {}
+ void visit(const nodes::Not &) override {}
+ void visit(const nodes::If &) override {}
+ void visit(const nodes::Error &) override {}
+ void visit(const nodes::TensorMap &) override {}
+ void visit(const nodes::TensorMapSubspaces &) override {}
+ void visit(const nodes::TensorJoin &) override {}
+ void visit(const nodes::TensorMerge &) override {}
+ void visit(const nodes::TensorReduce &) override {}
+ void visit(const nodes::TensorRename &) override {}
+ void visit(const nodes::TensorConcat &) override {}
+ void visit(const nodes::TensorCellCast &) override {}
+ void visit(const nodes::TensorCreate &) override {}
+ void visit(const nodes::TensorLambda &) override {}
+ void visit(const nodes::TensorPeek &) override {}
+ void visit(const nodes::Add &) override {}
+ void visit(const nodes::Sub &) override {}
+ void visit(const nodes::Mul &) override {}
+ void visit(const nodes::Div &) override {}
+ void visit(const nodes::Mod &) override {}
+ void visit(const nodes::Pow &) override {}
+ void visit(const nodes::Equal &) override {}
+ void visit(const nodes::NotEqual &) override {}
+ void visit(const nodes::Approx &) override {}
+ void visit(const nodes::Less &) override {}
+ void visit(const nodes::LessEqual &) override {}
+ void visit(const nodes::Greater &) override {}
+ void visit(const nodes::GreaterEqual &) override {}
+ void visit(const nodes::And &) override {}
+ void visit(const nodes::Or &) override {}
+ void visit(const nodes::Cos &) override {}
+ void visit(const nodes::Sin &) override {}
+ void visit(const nodes::Tan &) override {}
+ void visit(const nodes::Cosh &) override {}
+ void visit(const nodes::Sinh &) override {}
+ void visit(const nodes::Tanh &) override {}
+ void visit(const nodes::Acos &) override {}
+ void visit(const nodes::Asin &) override {}
+ void visit(const nodes::Atan &) override {}
+ void visit(const nodes::Exp &) override {}
+ void visit(const nodes::Log10 &) override {}
+ void visit(const nodes::Log &) override {}
+ void visit(const nodes::Sqrt &) override {}
+ void visit(const nodes::Ceil &) override {}
+ void visit(const nodes::Fabs &) override {}
+ void visit(const nodes::Floor &) override {}
+ void visit(const nodes::Atan2 &) override {}
+ void visit(const nodes::Ldexp &) override {}
+ void visit(const nodes::Pow2 &) override {}
+ void visit(const nodes::Fmod &) override {}
+ void visit(const nodes::Min &) override {}
+ void visit(const nodes::Max &) override {}
+ void visit(const nodes::IsNan &) override {}
+ void visit(const nodes::Relu &) override {}
+ void visit(const nodes::Sigmoid &) override {}
+ void visit(const nodes::Elu &) override {}
+ void visit(const nodes::Erf &) override {}
+ void visit(const nodes::Bit &) override {}
+ void visit(const nodes::Hamming &) override {}
};
}
diff --git a/eval/src/vespa/eval/eval/tensor_function.cpp b/eval/src/vespa/eval/eval/tensor_function.cpp
index b258b6c824e..14d486aeb48 100644
--- a/eval/src/vespa/eval/eval/tensor_function.cpp
+++ b/eval/src/vespa/eval/eval/tensor_function.cpp
@@ -12,6 +12,7 @@
#include <vespa/eval/instruction/generic_join.h>
#include <vespa/eval/instruction/generic_lambda.h>
#include <vespa/eval/instruction/generic_map.h>
+#include <vespa/eval/instruction/generic_map_subspaces.h>
#include <vespa/eval/instruction/generic_merge.h>
#include <vespa/eval/instruction/generic_peek.h>
#include <vespa/eval/instruction/generic_reduce.h>
@@ -172,6 +173,20 @@ Map::visit_self(vespalib::ObjectVisitor &visitor) const
//-----------------------------------------------------------------------------
+InterpretedFunction::Instruction
+MapSubspaces::compile_self(const ValueBuilderFactory &factory, Stash &stash) const
+{
+ return instruction::GenericMapSubspaces::make_instruction(*this, factory, stash);
+}
+
+void
+MapSubspaces::visit_self(vespalib::ObjectVisitor &visitor) const
+{
+ Super::visit_self(visitor);
+}
+
+//-----------------------------------------------------------------------------
+
Instruction
Join::compile_self(const ValueBuilderFactory &factory, Stash &stash) const
{
@@ -455,6 +470,11 @@ const TensorFunction &map(const TensorFunction &child, map_fun_t function, Stash
return stash.create<Map>(result_type, child, function);
}
+const TensorFunction &map_subspaces(const TensorFunction &child, const Function &function, NodeTypes node_types, Stash &stash) {
+ auto result_type = child.result_type().strip_indexed_dimensions().wrap(node_types.get_type(function.root()));
+ return stash.create<MapSubspaces>(result_type, child, function, std::move(node_types));
+}
+
const TensorFunction &join(const TensorFunction &lhs, const TensorFunction &rhs, join_fun_t function, Stash &stash) {
ValueType result_type = ValueType::join(lhs.result_type(), rhs.result_type());
return stash.create<Join>(result_type, lhs, rhs, function);
diff --git a/eval/src/vespa/eval/eval/tensor_function.h b/eval/src/vespa/eval/eval/tensor_function.h
index 24548bfae4d..2c703fbdfef 100644
--- a/eval/src/vespa/eval/eval/tensor_function.h
+++ b/eval/src/vespa/eval/eval/tensor_function.h
@@ -249,6 +249,29 @@ public:
//-----------------------------------------------------------------------------
+class MapSubspaces : public Op1
+{
+ using Super = Op1;
+private:
+ ValueType _inner_type;
+ std::shared_ptr<Function const> _lambda;
+ NodeTypes _lambda_types;
+public:
+ MapSubspaces(const ValueType &result_type_in, const TensorFunction &child_in, const Function &lambda_in, NodeTypes lambda_types_in)
+ : Super(result_type_in, child_in),
+ _inner_type(child_in.result_type().strip_mapped_dimensions()),
+ _lambda(lambda_in.shared_from_this()),
+ _lambda_types(std::move(lambda_types_in)) {}
+ const ValueType &inner_type() const { return _inner_type; }
+ const Function &lambda() const { return *_lambda; }
+ const NodeTypes &types() const { return _lambda_types; }
+ bool result_is_mutable() const override { return true; }
+ InterpretedFunction::Instruction compile_self(const ValueBuilderFactory &factory, Stash &stash) const final override;
+ void visit_self(vespalib::ObjectVisitor &visitor) const override;
+};
+
+//-----------------------------------------------------------------------------
+
class Join : public Op2
{
using Super = Op2;
@@ -463,6 +486,7 @@ const TensorFunction &const_value(const Value &value, Stash &stash);
const TensorFunction &inject(const ValueType &type, size_t param_idx, Stash &stash);
const TensorFunction &reduce(const TensorFunction &child, Aggr aggr, const std::vector<vespalib::string> &dimensions, Stash &stash);
const TensorFunction &map(const TensorFunction &child, map_fun_t function, Stash &stash);
+const TensorFunction &map_subspaces(const TensorFunction &child, const Function &function, NodeTypes node_types, Stash &stash);
const TensorFunction &join(const TensorFunction &lhs, const TensorFunction &rhs, join_fun_t function, Stash &stash);
const TensorFunction &merge(const TensorFunction &lhs, const TensorFunction &rhs, join_fun_t function, Stash &stash);
const TensorFunction &concat(const TensorFunction &lhs, const TensorFunction &rhs, const vespalib::string &dimension, Stash &stash);
diff --git a/eval/src/vespa/eval/eval/tensor_nodes.cpp b/eval/src/vespa/eval/eval/tensor_nodes.cpp
index bfcd1f979e2..ef2718234b2 100644
--- a/eval/src/vespa/eval/eval/tensor_nodes.cpp
+++ b/eval/src/vespa/eval/eval/tensor_nodes.cpp
@@ -5,15 +5,16 @@
namespace vespalib::eval::nodes {
-void TensorMap ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
-void TensorJoin ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
-void TensorMerge ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
-void TensorReduce ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
-void TensorRename ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
-void TensorConcat ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
-void TensorCellCast::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
-void TensorCreate ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
-void TensorLambda ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
-void TensorPeek ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
+void TensorMap ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
+void TensorMapSubspaces::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
+void TensorJoin ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
+void TensorMerge ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
+void TensorReduce ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
+void TensorRename ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
+void TensorConcat ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
+void TensorCellCast ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
+void TensorCreate ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
+void TensorLambda ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
+void TensorPeek ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
}
diff --git a/eval/src/vespa/eval/eval/tensor_nodes.h b/eval/src/vespa/eval/eval/tensor_nodes.h
index 33f9cc5e39c..6ed19e81712 100644
--- a/eval/src/vespa/eval/eval/tensor_nodes.h
+++ b/eval/src/vespa/eval/eval/tensor_nodes.h
@@ -44,6 +44,36 @@ public:
}
};
+class TensorMapSubspaces : public Node {
+private:
+ Node_UP _child;
+ std::shared_ptr<Function const> _lambda;
+public:
+ TensorMapSubspaces(Node_UP child, std::shared_ptr<Function const> lambda)
+ : _child(std::move(child)), _lambda(std::move(lambda)) {}
+ const Node &child() const { return *_child; }
+ const Function &lambda() const { return *_lambda; }
+ vespalib::string dump(DumpContext &ctx) const override {
+ vespalib::string str;
+ str += "map_subspaces(";
+ str += _child->dump(ctx);
+ str += ",";
+ str += _lambda->dump_as_lambda();
+ str += ")";
+ return str;
+ }
+ void accept(NodeVisitor &visitor) const override;
+ size_t num_children() const override { return 1; }
+ const Node &get_child(size_t idx) const override {
+ (void) idx;
+ assert(idx == 0);
+ return *_child;
+ }
+ void detach_children(NodeHandler &handler) override {
+ handler.handle(std::move(_child));
+ }
+};
+
class TensorJoin : public Node {
private:
Node_UP _lhs;
diff --git a/eval/src/vespa/eval/eval/tensor_spec.cpp b/eval/src/vespa/eval/eval/tensor_spec.cpp
index c9401606600..323f9eaf0fe 100644
--- a/eval/src/vespa/eval/eval/tensor_spec.cpp
+++ b/eval/src/vespa/eval/eval/tensor_spec.cpp
@@ -7,6 +7,7 @@
#include "value.h"
#include "value_codec.h"
#include "value_type.h"
+#include <vespa/vespalib/util/require.h>
#include <vespa/vespalib/util/overload.h>
#include <vespa/vespalib/util/visit_ranges.h>
#include <vespa/vespalib/util/stringfmt.h>
@@ -182,19 +183,19 @@ struct NormalizeTensorSpec {
size_t dense_key = 0;
auto binding = entry.first.begin();
for (const auto &dim : type.dimensions()) {
- assert(binding != entry.first.end());
- assert(dim.name == binding->first);
- assert(dim.is_mapped() == binding->second.is_mapped());
+ REQUIRE(binding != entry.first.end());
+ REQUIRE(dim.name == binding->first);
+ REQUIRE(dim.is_mapped() == binding->second.is_mapped());
if (dim.is_mapped()) {
sparse_key.push_back(binding->second.name);
} else {
- assert(binding->second.index < dim.size);
+ REQUIRE(binding->second.index < dim.size);
dense_key = (dense_key * dim.size) + binding->second.index;
}
++binding;
}
- assert(binding == entry.first.end());
- assert(dense_key < map.values_per_entry());
+ REQUIRE(binding == entry.first.end());
+ REQUIRE(dense_key < map.values_per_entry());
auto [tag, ignore] = map.lookup_or_add_entry(ConstArrayRef<vespalib::stringref>(sparse_key));
map.get_values(tag)[dense_key] = entry.second;
}
@@ -212,7 +213,7 @@ struct NormalizeTensorSpec {
address.emplace(dim.name, *sparse_addr_iter++);
}
}
- assert(sparse_addr_iter == keys.end());
+ REQUIRE(sparse_addr_iter == keys.end());
for (size_t i = 0; i < values.size(); ++i) {
size_t dense_key = i;
for (auto dim = type.dimensions().rbegin();
@@ -364,7 +365,12 @@ TensorSpec::normalize() const
if (my_type.is_error()) {
return TensorSpec(my_type.to_spec());
}
- return typify_invoke<1,TypifyCellType,NormalizeTensorSpec>(my_type.cell_type(), my_type, *this);
+ try {
+ return typify_invoke<1,TypifyCellType,NormalizeTensorSpec>(my_type.cell_type(), my_type, *this);
+ } catch (RequireFailedException &e) {
+ fprintf(stderr, "TensorSpec::normalize: invalid spec: %s\n", to_string().c_str());
+ assert(false); // preserve crashing behavior
+ }
}
vespalib::string
diff --git a/eval/src/vespa/eval/eval/test/reference_evaluation.cpp b/eval/src/vespa/eval/eval/test/reference_evaluation.cpp
index 9bfa314493a..5a1fd2041dd 100644
--- a/eval/src/vespa/eval/eval/test/reference_evaluation.cpp
+++ b/eval/src/vespa/eval/eval/test/reference_evaluation.cpp
@@ -136,6 +136,13 @@ struct EvalNode : public NodeVisitor {
result = ReferenceOperations::peek(spec, children);
}
+ void eval_map_subspaces(const Node &node, const Node &lambda) {
+ auto fun = [&](const TensorSpec &subspace) {
+ return eval_node(lambda, {subspace});
+ };
+ result = ReferenceOperations::map_subspaces(eval_node(node, params), fun);
+ }
+
//-------------------------------------------------------------------------
void visit(const Number &node) override {
@@ -176,6 +183,9 @@ struct EvalNode : public NodeVisitor {
};
eval_map(node.child(), my_op1);
}
+ void visit(const TensorMapSubspaces &node) override {
+ eval_map_subspaces(node.child(), node.lambda().root());
+ }
void visit(const TensorJoin &node) override {
auto my_op2 = [&](double a, double b) {
return ReferenceEvaluation::eval(node.lambda(), {num(a), num(b)}).as_double();
diff --git a/eval/src/vespa/eval/eval/test/reference_operations.cpp b/eval/src/vespa/eval/eval/test/reference_operations.cpp
index 5d79f168aaa..6771eda91e3 100644
--- a/eval/src/vespa/eval/eval/test/reference_operations.cpp
+++ b/eval/src/vespa/eval/eval/test/reference_operations.cpp
@@ -176,6 +176,52 @@ TensorSpec ReferenceOperations::map(const TensorSpec &in_a, map_fun_t func) {
}
+TensorSpec ReferenceOperations::map_subspaces(const TensorSpec &a, map_subspace_fun_t fun) {
+ auto type = ValueType::from_spec(a.type());
+ auto outer_type = type.strip_indexed_dimensions();
+ auto inner_type = type.strip_mapped_dimensions();
+ auto inner_type_str = inner_type.to_spec();
+ auto lambda_res_type = ValueType::from_spec(fun(TensorSpec(inner_type_str).normalize()).type());
+ auto res_type = outer_type.wrap(lambda_res_type);
+ auto split = [](const auto &addr) {
+ TensorSpec::Address outer;
+ TensorSpec::Address inner;
+ for (const auto &[name, label]: addr) {
+ if (label.is_mapped()) {
+ outer.insert_or_assign(name, label);
+ } else {
+ inner.insert_or_assign(name, label);
+ }
+ }
+ return std::make_pair(outer, inner);
+ };
+ auto combine = [](const auto &outer, const auto &inner) {
+ TensorSpec::Address addr;
+ for (const auto &[name, label]: outer) {
+ addr.insert_or_assign(name, label);
+ }
+ for (const auto &[name, label]: inner) {
+ addr.insert_or_assign(name, label);
+ }
+ return addr;
+ };
+ std::map<TensorSpec::Address,TensorSpec> subspaces;
+ for (const auto &[addr, value]: a.cells()) {
+ auto [outer, inner] = split(addr);
+ auto &subspace = subspaces.try_emplace(outer, inner_type_str).first->second;
+ subspace.add(inner, value);
+ }
+ TensorSpec result(res_type.to_spec());
+ for (const auto &[outer, subspace]: subspaces) {
+ auto mapped = fun(subspace);
+ for (const auto &[inner, value]: mapped.cells()) {
+ result.add(combine(outer, inner), value);
+ }
+ }
+ return result.normalize();
+}
+
+
TensorSpec ReferenceOperations::merge(const TensorSpec &in_a, const TensorSpec &in_b, join_fun_t fun) {
auto a = in_a.normalize();
auto b = in_b.normalize();
diff --git a/eval/src/vespa/eval/eval/test/reference_operations.h b/eval/src/vespa/eval/eval/test/reference_operations.h
index 85aa73ec958..dd9c4b143ed 100644
--- a/eval/src/vespa/eval/eval/test/reference_operations.h
+++ b/eval/src/vespa/eval/eval/test/reference_operations.h
@@ -19,6 +19,7 @@ struct ReferenceOperations {
using map_fun_t = std::function<double(double)>;
using join_fun_t = std::function<double(double,double)>;
using lambda_fun_t = std::function<double(const std::vector<size_t> &dimension_indexes)>;
+ using map_subspace_fun_t = std::function<TensorSpec(const TensorSpec &subspace)>;
// mapping from cell address to index of child that computes the cell value
using CreateSpec = tensor_function::Create::Spec;
@@ -33,6 +34,7 @@ struct ReferenceOperations {
static TensorSpec create(const vespalib::string &type, const CreateSpec &spec, const std::vector<TensorSpec> &children);
static TensorSpec join(const TensorSpec &a, const TensorSpec &b, join_fun_t function);
static TensorSpec map(const TensorSpec &a, map_fun_t func);
+ static TensorSpec map_subspaces(const TensorSpec &a, map_subspace_fun_t fun);
static TensorSpec merge(const TensorSpec &a, const TensorSpec &b, join_fun_t fun);
static TensorSpec peek(const PeekSpec &spec, const std::vector<TensorSpec> &children);
static TensorSpec reduce(const TensorSpec &a, Aggr aggr, const std::vector<vespalib::string> &dims);
diff --git a/eval/src/vespa/eval/eval/value_type.cpp b/eval/src/vespa/eval/eval/value_type.cpp
index 1a83de9b0f9..fe70622de4e 100644
--- a/eval/src/vespa/eval/eval/value_type.cpp
+++ b/eval/src/vespa/eval/eval/value_type.cpp
@@ -138,6 +138,25 @@ struct Renamer {
bool matched_all() const { return (match_cnt == from.size()); }
};
+auto filter(const std::vector<Dimension> &dims, auto keep) {
+ std::vector<Dimension> result;
+ result.reserve(dims.size());
+ for (const auto &dim: dims) {
+ if (keep(dim)) {
+ result.push_back(dim);
+ }
+ }
+ return result;
+}
+
+auto strip(CellType old_cell_type, const std::vector<Dimension> &old_dims, auto discard) {
+ auto new_dims = filter(old_dims, [discard](const auto &dim){ return !discard(dim); });
+ if (new_dims.empty()) {
+ return ValueType::double_type();
+ }
+ return ValueType::make_type(old_cell_type, std::move(new_dims));
+}
+
} // namespace vespalib::eval::<unnamed>
constexpr ValueType::Dimension::size_type ValueType::Dimension::npos;
@@ -245,37 +264,19 @@ ValueType::dense_subspace_size() const
std::vector<ValueType::Dimension>
ValueType::nontrivial_indexed_dimensions() const
{
- std::vector<ValueType::Dimension> result;
- for (const auto &dim: dimensions()) {
- if (dim.is_indexed() && !dim.is_trivial()) {
- result.push_back(dim);
- }
- }
- return result;
+ return filter(_dimensions, [](const auto &dim){ return !dim.is_trivial() && dim.is_indexed(); });
}
std::vector<ValueType::Dimension>
ValueType::indexed_dimensions() const
{
- std::vector<ValueType::Dimension> result;
- for (const auto &dim: dimensions()) {
- if (dim.is_indexed()) {
- result.push_back(dim);
- }
- }
- return result;
+ return filter(_dimensions, [](const auto &dim){ return dim.is_indexed(); });
}
std::vector<ValueType::Dimension>
ValueType::mapped_dimensions() const
{
- std::vector<ValueType::Dimension> result;
- for (const auto &dim: dimensions()) {
- if (dim.is_mapped()) {
- result.push_back(dim);
- }
- }
- return result;
+ return filter(_dimensions, [](const auto &dim){ return dim.is_mapped(); });
}
size_t
@@ -312,6 +313,31 @@ ValueType::dimension_names() const
}
ValueType
+ValueType::strip_mapped_dimensions() const
+{
+ return error_if(_error, strip(_cell_type, _dimensions,
+ [](const auto &dim){ return dim.is_mapped(); }));
+}
+
+ValueType
+ValueType::strip_indexed_dimensions() const
+{
+ return error_if(_error, strip(_cell_type, _dimensions,
+ [](const auto &dim){ return dim.is_indexed(); }));
+}
+
+ValueType
+ValueType::wrap(const ValueType &inner)
+{
+ MyJoin result(_dimensions, inner._dimensions);
+ auto meta = cell_meta().wrap(inner.cell_meta());
+ return error_if(_error || inner._error || result.mismatch ||
+ (count_indexed_dimensions() > 0) ||
+ (inner.count_mapped_dimensions() > 0),
+ make_type(meta.cell_type, std::move(result.dimensions)));
+}
+
+ValueType
ValueType::map() const
{
auto meta = cell_meta().map();
diff --git a/eval/src/vespa/eval/eval/value_type.h b/eval/src/vespa/eval/eval/value_type.h
index b35e23ee4e6..a65dde398c2 100644
--- a/eval/src/vespa/eval/eval/value_type.h
+++ b/eval/src/vespa/eval/eval/value_type.h
@@ -81,6 +81,9 @@ public:
}
bool operator!=(const ValueType &rhs) const noexcept { return !(*this == rhs); }
+ ValueType strip_mapped_dimensions() const;
+ ValueType strip_indexed_dimensions() const;
+ ValueType wrap(const ValueType &inner);
ValueType map() const;
ValueType reduce(const std::vector<vespalib::string> &dimensions_in) const;
ValueType peek(const std::vector<vespalib::string> &dimensions_in) const;
diff --git a/eval/src/vespa/eval/instruction/CMakeLists.txt b/eval/src/vespa/eval/instruction/CMakeLists.txt
index 22fa58a08fc..67203b6c7c8 100644
--- a/eval/src/vespa/eval/instruction/CMakeLists.txt
+++ b/eval/src/vespa/eval/instruction/CMakeLists.txt
@@ -24,6 +24,7 @@ vespa_add_library(eval_instruction OBJECT
generic_join.cpp
generic_lambda.cpp
generic_map.cpp
+ generic_map_subspaces.cpp
generic_merge.cpp
generic_peek.cpp
generic_reduce.cpp
diff --git a/eval/src/vespa/eval/instruction/generic_map_subspaces.cpp b/eval/src/vespa/eval/instruction/generic_map_subspaces.cpp
new file mode 100644
index 00000000000..1238d4f4e57
--- /dev/null
+++ b/eval/src/vespa/eval/instruction/generic_map_subspaces.cpp
@@ -0,0 +1,118 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "generic_map_subspaces.h"
+
+using namespace vespalib::eval::tensor_function;
+
+namespace vespalib::eval::instruction {
+
+using Instruction = InterpretedFunction::Instruction;
+using State = InterpretedFunction::State;
+
+namespace {
+
+//-----------------------------------------------------------------------------
+
+struct InterpretedParams {
+ const ValueType &result_type;
+ const ValueType &inner_type;
+ InterpretedFunction fun;
+ size_t in_size;
+ size_t out_size;
+ bool direct_in;
+ bool direct_out;
+ InterpretedParams(const MapSubspaces &map_subspaces, const ValueBuilderFactory &factory)
+ : result_type(map_subspaces.result_type()),
+ inner_type(map_subspaces.inner_type()),
+ fun(factory, map_subspaces.lambda().root(), map_subspaces.types()),
+ in_size(inner_type.dense_subspace_size()),
+ out_size(result_type.dense_subspace_size()),
+ direct_in(map_subspaces.child().result_type().cell_type() == inner_type.cell_type()),
+ direct_out(map_subspaces.types().get_type(map_subspaces.lambda().root()).cell_type() == result_type.cell_type())
+ {
+ assert(direct_in || (in_size == 1));
+ assert(direct_out || (out_size == 1));
+ }
+};
+
+struct ParamView final : Value, LazyParams {
+ const ValueType &my_type;
+ TypedCells my_cells;
+ double value;
+ bool direct;
+public:
+ ParamView(const ValueType &type_in, bool direct_in)
+ : my_type(type_in), my_cells(), value(0.0), direct(direct_in) {}
+ const ValueType &type() const final override { return my_type; }
+ template <typename ICT>
+ void adjust(const ICT *cells, size_t size) {
+ if (direct) {
+ my_cells = TypedCells(cells, get_cell_type<ICT>(), size);
+ } else {
+ value = cells[0];
+ my_cells = TypedCells(&value, CellType::DOUBLE, 1);
+ }
+ }
+ TypedCells cells() const final override { return my_cells; }
+ const Index &index() const final override { return TrivialIndex::get(); }
+ MemoryUsage get_memory_usage() const final override { return self_memory_usage<ParamView>(); }
+ const Value &resolve(size_t, Stash &) const final override { return *this; }
+};
+
+template <typename OCT>
+struct ResultFiller {
+ OCT *dst;
+ bool direct;
+public:
+ ResultFiller(OCT *dst_in, bool direct_out)
+ : dst(dst_in), direct(direct_out) {}
+ void fill(const Value &value) {
+ if (direct) {
+ auto cells = value.cells();
+ memcpy(dst, cells.data, sizeof(OCT) * cells.size);
+ dst += cells.size;
+ } else {
+ *dst++ = value.as_double();
+ }
+ }
+};
+
+template <typename ICT, typename OCT>
+void my_generic_map_subspaces_op(InterpretedFunction::State &state, uint64_t param) {
+ const InterpretedParams &params = unwrap_param<InterpretedParams>(param);
+ InterpretedFunction::Context ctx(params.fun);
+ const Value &input = state.peek(0);
+ const ICT *src = input.cells().typify<ICT>().data();
+ size_t num_subspaces = input.index().size();
+ auto res_cells = state.stash.create_uninitialized_array<OCT>(num_subspaces * params.out_size);
+ ResultFiller result_filler(res_cells.data(), params.direct_out);
+ ParamView param_view(params.inner_type, params.direct_in);
+ for (size_t i = 0; i < num_subspaces; ++i) {
+ param_view.adjust(src, params.in_size);
+ src += params.in_size;
+ result_filler.fill(params.fun.eval(ctx, param_view));
+ }
+ state.pop_push(state.stash.create<ValueView>(params.result_type, input.index(), TypedCells(res_cells)));
+}
+
+struct SelectGenericMapSubspacesOp {
+ template <typename ICT, typename OCT> static auto invoke() {
+ return my_generic_map_subspaces_op<ICT,OCT>;
+ }
+};
+
+//-----------------------------------------------------------------------------
+
+} // namespace <unnamed>
+
+Instruction
+GenericMapSubspaces::make_instruction(const tensor_function::MapSubspaces &map_subspaces_in,
+ const ValueBuilderFactory &factory, Stash &stash)
+{
+ InterpretedParams &params = stash.create<InterpretedParams>(map_subspaces_in, factory);
+ auto op = typify_invoke<2,TypifyCellType,SelectGenericMapSubspacesOp>(map_subspaces_in.child().result_type().cell_type(),
+ params.result_type.cell_type());
+ return Instruction(op, wrap_param<InterpretedParams>(params));
+}
+
+} // namespace
diff --git a/eval/src/vespa/eval/instruction/generic_map_subspaces.h b/eval/src/vespa/eval/instruction/generic_map_subspaces.h
new file mode 100644
index 00000000000..f95ded60a1b
--- /dev/null
+++ b/eval/src/vespa/eval/instruction/generic_map_subspaces.h
@@ -0,0 +1,17 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include <vespa/eval/eval/interpreted_function.h>
+#include <vespa/eval/eval/tensor_function.h>
+#include <vespa/eval/eval/value_type.h>
+
+namespace vespalib::eval::instruction {
+
+struct GenericMapSubspaces {
+ static InterpretedFunction::Instruction
+ make_instruction(const tensor_function::MapSubspaces &map_subspaces_in,
+ const ValueBuilderFactory &factory, Stash &stash);
+};
+
+} // namespace