diff options
author | Håvard Pettersen <havardpe@yahooinc.com> | 2023-10-26 08:34:13 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@yahooinc.com> | 2023-11-01 15:20:13 +0000 |
commit | c70562eb766a5205fba797f62456652919e7cd3d (patch) | |
tree | ec0b81cc3ff37c489db3931743e0cc16b0c2cd66 /eval | |
parent | e8c83ca7ee15b11768e9a9b50aea479103139277 (diff) |
map_subspaces operation
Diffstat (limited to 'eval')
28 files changed, 879 insertions, 275 deletions
diff --git a/eval/CMakeLists.txt b/eval/CMakeLists.txt index fac6691bbff..11d5ecddaf5 100644 --- a/eval/CMakeLists.txt +++ b/eval/CMakeLists.txt @@ -29,6 +29,7 @@ vespa_define_module( src/tests/eval/int8float src/tests/eval/interpreted_function src/tests/eval/llvm_stress + src/tests/eval/map_subspaces src/tests/eval/multiply_add src/tests/eval/nested_loop src/tests/eval/node_tools diff --git a/eval/src/tests/eval/map_subspaces/CMakeLists.txt b/eval/src/tests/eval/map_subspaces/CMakeLists.txt new file mode 100644 index 00000000000..90b2ce07791 --- /dev/null +++ b/eval/src/tests/eval/map_subspaces/CMakeLists.txt @@ -0,0 +1,8 @@ +# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(eval_map_subspaces_test_app TEST + SOURCES + map_subspaces_test.cpp + DEPENDS + vespaeval +) +vespa_add_test(NAME eval_map_subspaces_test_app COMMAND eval_map_subspaces_test_app) diff --git a/eval/src/tests/eval/map_subspaces/map_subspaces_test.cpp b/eval/src/tests/eval/map_subspaces/map_subspaces_test.cpp new file mode 100644 index 00000000000..278d49992be --- /dev/null +++ b/eval/src/tests/eval/map_subspaces/map_subspaces_test.cpp @@ -0,0 +1,103 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/vespalib/testkit/test_kit.h> +#include <vespa/eval/eval/tensor_function.h> +#include <vespa/eval/eval/simple_value.h> +#include <vespa/eval/eval/fast_value.h> +#include <vespa/eval/eval/test/gen_spec.h> +#include <vespa/eval/eval/test/eval_fixture.h> +#include <vespa/eval/eval/tensor_nodes.h> + +#include <vespa/vespalib/util/stringfmt.h> +#include <vespa/vespalib/util/stash.h> + +using namespace vespalib; +using namespace vespalib::eval; +using namespace vespalib::eval::test; +using namespace vespalib::eval::tensor_function; + +void verify(const vespalib::string &a, const vespalib::string &expr, const vespalib::string &result) { + EvalFixture::ParamRepo param_repo; + param_repo.add("a", TensorSpec::from_expr(a)); + auto expect = TensorSpec::from_expr(result); + EXPECT_FALSE(ValueType::from_spec(expect.type()).is_error()); + EXPECT_EQUAL(EvalFixture::ref(expr, param_repo), expect); + EXPECT_EQUAL(EvalFixture::prod(expr, param_repo), expect); +} + +//----------------------------------------------------------------------------- + +TEST("require that simple map_subspaces work") { + TEST_DO(verify("tensor(x{},y[3]):{foo:[1,2,3],bar:[4,5,6]}", + "map_subspaces(a,f(t)(tensor(y[2])(t{y:(y)}+t{y:(y+1)})))", + "tensor(x{},y[2]):{foo:[3,5],bar:[9,11]}")); +} + +TEST("require that scalars can be used with map_subspaces") { + TEST_DO(verify("3.0", + "map_subspaces(a,f(n)(n+5.0))", + "8.0")); +} + +TEST("require that outer cell type is decayed when inner type is double") { + TEST_DO(verify("tensor<int8>(x{}):{foo:3,bar:7}", + "map_subspaces(a,f(n)(n+2))", + "tensor<float>(x{}):{foo:5,bar:9}")); +} + +TEST("require that inner cell type is used directly without decay") { + TEST_DO(verify("tensor(x{},y[3]):{foo:[1,2,3],bar:[4,5,6]}", + "map_subspaces(a,f(t)(cell_cast(t,int8)))", + "tensor<int8>(x{},y[3]):{foo:[1,2,3],bar:[4,5,6]}")); + TEST_DO(verify("tensor(y[3]):[1,2,3]", + "map_subspaces(a,f(t)(cell_cast(t,int8)))", + "tensor<int8>(y[3]):[1,2,3]")); +} + +TEST("require that map_subspaces can be nested") { + TEST_DO(verify("tensor(x{},y[3]):{foo:[1,2,3],bar:[4,5,6]}", + "map_subspaces(a,f(a)(5+map_subspaces(a,f(t)(tensor(y[2])(t{y:(y)}+t{y:(y+1)})))))", + "tensor(x{},y[2]):{foo:[8,10],bar:[14,16]}")); +} + +size_t count_nodes(const NodeTypes &types) { + size_t cnt = 0; + types.each([&](const auto &, const auto &){++cnt;}); + return cnt; +} + +void check_errors(const NodeTypes &types) { + for (const auto &err: types.errors()) { + fprintf(stderr, "%s\n", err.c_str()); + } + ASSERT_EQUAL(types.errors().size(), 0u); +} + +TEST("require that type resolving also include nodes from the mapping lambda function") { + auto fun = Function::parse("map_subspaces(a,f(a)(map_subspaces(a,f(t)(tensor(y[2])(t{y:(y)}+t{y:(y+1)})))))"); + NodeTypes types(*fun, {ValueType::from_spec("tensor(x{},y[3])")}); + check_errors(types); + auto map_subspaces = nodes::as<nodes::TensorMapSubspaces>(fun->root()); + ASSERT_TRUE(map_subspaces != nullptr); + EXPECT_EQUAL(types.get_type(*map_subspaces).to_spec(), "tensor(x{},y[2])"); + EXPECT_EQUAL(types.get_type(map_subspaces->lambda().root()).to_spec(), "tensor(y[2])"); + + NodeTypes copy = types.export_types(fun->root()); + check_errors(copy); + EXPECT_EQUAL(count_nodes(types), count_nodes(copy)); + + NodeTypes map_types = copy.export_types(map_subspaces->lambda().root()); + check_errors(map_types); + EXPECT_LESS(count_nodes(map_types), count_nodes(copy)); + + auto inner_map = nodes::as<nodes::TensorMapSubspaces>(map_subspaces->lambda().root()); + ASSERT_TRUE(inner_map != nullptr); + NodeTypes inner_types = map_types.export_types(inner_map->lambda().root()); + check_errors(inner_types); + EXPECT_LESS(count_nodes(inner_types), count_nodes(map_types)); + + // [lambda, peek, t, y, +, peek, t, y, +, 1] are the 10 nodes: + EXPECT_EQUAL(count_nodes(inner_types), 10u); +} + +TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/eval/src/tests/eval/node_types/node_types_test.cpp b/eval/src/tests/eval/node_types/node_types_test.cpp index d58f9fda943..ad5426e2a99 100644 --- a/eval/src/tests/eval/node_types/node_types_test.cpp +++ b/eval/src/tests/eval/node_types/node_types_test.cpp @@ -354,6 +354,42 @@ TEST("require that tensor cell_cast resolves correct type") { TEST_DO(verify("cell_cast(tensor<float>(x{},y[5]),int8)", "tensor<int8>(x{},y[5])")); } +TEST("require that tensor map_subspace resolves correct type") { + // double input + TEST_DO(verify("map_subspaces(double, f(a)(a))", "double")); + TEST_DO(verify("map_subspaces(double, f(a)(tensor<int8>(y[2]):[a,a]))", "tensor<int8>(y[2])")); + + // sparse input + TEST_DO(verify("map_subspaces(tensor<float>(x{}), f(a)(a))", "tensor<float>(x{})")); + TEST_DO(verify("map_subspaces(tensor<int8>(x{}), f(a)(a))", "tensor<float>(x{})")); // NB: decay + TEST_DO(verify("map_subspaces(tensor<float>(x{}), f(a)(tensor<int8>(y[2]):[a,a]))", "tensor<int8>(x{},y[2])")); + + // dense input + TEST_DO(verify("map_subspaces(tensor<float>(y[10]), f(a)(a))", "tensor<float>(y[10])")); + TEST_DO(verify("map_subspaces(tensor<int8>(y[10]), f(a)(a))", "tensor<int8>(y[10])")); // NB: no decay + TEST_DO(verify("map_subspaces(tensor<float>(y[10]), f(a)(reduce(a,sum)))", "double")); + TEST_DO(verify("map_subspaces(tensor<float>(y[10]), f(a)(cell_cast(a,int8)))", "tensor<int8>(y[10])")); + TEST_DO(verify("map_subspaces(tensor<int8>(y[10]), f(a)(a*tensor<int8>(z[2]):[a{y:0},a{y:1}]))", "tensor<float>(y[10],z[2])")); + + // mixed input + TEST_DO(verify("map_subspaces(tensor<float>(x{},y[10]), f(a)(a))", "tensor<float>(x{},y[10])")); + TEST_DO(verify("map_subspaces(tensor<int8>(x{},y[10]), f(a)(a))", "tensor<int8>(x{},y[10])")); + TEST_DO(verify("map_subspaces(tensor<int8>(x{},y[10]), f(a)(map_subspaces(a, f(b)(b))))", "tensor<int8>(x{},y[10])")); + TEST_DO(verify("map_subspaces(tensor<int8>(x{},y[10]), f(a)(map(a, f(b)(b))))", "tensor<float>(x{},y[10])")); + TEST_DO(verify("map_subspaces(tensor<float>(x{},y[10]), f(y)(cell_cast(y,int8)))", "tensor<int8>(x{},y[10])")); + TEST_DO(verify("map_subspaces(tensor<float>(x{},y[10]), f(y)(reduce(y,sum)))", "tensor<float>(x{})")); + TEST_DO(verify("map_subspaces(tensor<int8>(x{},y[10]), f(y)(reduce(y,sum)))", "tensor<float>(x{})")); + TEST_DO(verify("map_subspaces(tensor<float>(x{},y[10]), f(y)(concat(concat(y,y,y),y,y)))", "tensor<float>(x{},y[30])")); + TEST_DO(verify("map_subspaces(tensor<float>(x{},y[10]), f(y)(y*tensor<float>(z[5])(z+3)))", "tensor<float>(x{},y[10],z[5])")); + + // error cases + TEST_DO(verify("map_subspaces(error, f(a)(a))", "error")); + TEST_DO(verify("map_subspaces(double, f(a)(tensor(x[5])(x)+tensor(x[7])(x)))", "error")); + TEST_DO(verify("map_subspaces(tensor<float>(x{}), f(a)(tensor(y{}):{a:3}))", "error")); + TEST_DO(verify("map_subspaces(tensor<float>(y[10]), f(a)(a+tensor(y[7])(y)))", "error")); + TEST_DO(verify("map_subspaces(tensor<float>(x{},y[10]), f(y)(y*tensor<float>(x[5])(x+3)))", "error")); +} + TEST("require that double only expressions can be detected") { auto plain_fun = Function::parse("1+2"); auto complex_fun = Function::parse("reduce(a,sum)"); diff --git a/eval/src/tests/eval/reference_evaluation/reference_evaluation_test.cpp b/eval/src/tests/eval/reference_evaluation/reference_evaluation_test.cpp index bcb738781ad..6df1a7fdb34 100644 --- a/eval/src/tests/eval/reference_evaluation/reference_evaluation_test.cpp +++ b/eval/src/tests/eval/reference_evaluation/reference_evaluation_test.cpp @@ -130,6 +130,12 @@ TEST(ReferenceEvaluationTest, map_expression_works) { EXPECT_EQ(ref_eval("map(a,f(x)(x*2+3))", {a}), expect); } +TEST(ReferenceEvaluationTest, map_subspaces_expression_works) { + auto a = make_val("tensor(x{},y[3]):{foo:[1,2,3],bar:[4,5,6]}"); + auto expect = make_val("tensor(x{},y[2]):{foo:[3,5],bar:[9,11]}"); + EXPECT_EQ(ref_eval("map_subspaces(a,f(x)(tensor(y[2])(x{y:(y)}+x{y:(y+1)})))", {a}), expect); +} + TEST(ReferenceEvaluationTest, join_expression_works) { auto a = make_val("tensor(x[2]):[1,2]"); auto b = make_val("tensor(y[2]):[3,4]"); diff --git a/eval/src/tests/eval/reference_operations/reference_operations_test.cpp b/eval/src/tests/eval/reference_operations/reference_operations_test.cpp index ee876f67f34..e125a29d75f 100644 --- a/eval/src/tests/eval/reference_operations/reference_operations_test.cpp +++ b/eval/src/tests/eval/reference_operations/reference_operations_test.cpp @@ -49,6 +49,12 @@ TensorSpec sparse_1d_all_two() { .add({{"c", "qux"}}, 2.0); } +TensorSpec spec(const vespalib::string &expr) { + auto result = TensorSpec::from_expr(expr); + EXPECT_FALSE(ValueType::from_spec(result.type()).is_error()); + return result; +} + //----------------------------------------------------------------------------- TEST(ReferenceConcatTest, concat_numbers) { @@ -234,6 +240,62 @@ TEST(ReferenceMapTest, map_mixed_tensor) { //----------------------------------------------------------------------------- +TEST(ReferenceMapSubspacesTest, map_vectors) { + auto input = spec("tensor(x{},y[3]):{foo:[1,2,3],bar:[4,5,6]}"); + auto fun = [&](const TensorSpec &space) { + EXPECT_EQ(space.type(), "tensor(y[3])"); + size_t i = 0; + double a = 0.0; + double b = 0.0; + for (const auto &[addr, value]: space.cells()) { + if (i < 2) { + a += value; + } + if (i > 0) { + b += value; + } + ++i; + } + TensorSpec result("tensor(y[2])"); + result.add({{"y", 0}}, a); + result.add({{"y", 1}}, b); + return result; + }; + auto output = ReferenceOperations::map_subspaces(input, fun); + auto expect = spec("tensor(x{},y[2]):{foo:[3,5],bar:[9,11]}"); + EXPECT_EQ(output, expect); +} + +TEST(ReferenceMapSubspacesTest, map_numbers_with_external_decay) { + auto input = spec("tensor<bfloat16>(x{}):{foo:3,bar:5}"); + auto fun = [&](const TensorSpec &space) { + EXPECT_EQ(space.type(), "double"); + TensorSpec result("double"); + result.add({}, space.cells().begin()->second + 4.0); + return result; + }; + auto output = ReferenceOperations::map_subspaces(input, fun); + auto expect = spec("tensor<float>(x{}):{foo:7,bar:9}"); + EXPECT_EQ(output, expect); +} + +TEST(ReferenceMapSubspacesTest, cast_cells_without_internal_decay) { + auto input = spec("tensor<float>(x{},y[3]):{foo:[1,2,3],bar:[4,5,6]}"); + auto fun = [&](const TensorSpec &space) { + EXPECT_EQ(space.type(), "tensor<float>(y[3])"); + TensorSpec result("tensor<bfloat16>(y[3])"); + for (const auto &[addr, value]: space.cells()) { + result.add(addr, value); + } + return result; + }; + auto output = ReferenceOperations::map_subspaces(input, fun); + auto expect = spec("tensor<bfloat16>(x{},y[3]):{foo:[1,2,3],bar:[4,5,6]}"); + EXPECT_EQ(output, expect); +} + +//----------------------------------------------------------------------------- + TEST(ReferenceMergeTest, simple_mixed_merge) { auto a = mixed_5d_input(false); auto b = TensorSpec("tensor(a[3],b[1],c{},d[5],e{})") diff --git a/eval/src/tests/eval/value_type/value_type_test.cpp b/eval/src/tests/eval/value_type/value_type_test.cpp index 84b9d685ee8..f472a5fb6a6 100644 --- a/eval/src/tests/eval/value_type/value_type_test.cpp +++ b/eval/src/tests/eval/value_type/value_type_test.cpp @@ -351,6 +351,36 @@ TEST("require that mapped dimensions can be obtained") { TEST_DO(my_check(type("tensor(a[1],b[1],x{},y[10],z[1])").mapped_dimensions())); } +TEST("require that mapped dimensions can be stripped") { + EXPECT_EQUAL(type("error").strip_mapped_dimensions(), type("error")); + EXPECT_EQUAL(type("double").strip_mapped_dimensions(), type("double")); + EXPECT_EQUAL(type("tensor<float>(x{})").strip_mapped_dimensions(), type("double")); + EXPECT_EQUAL(type("tensor<float>(x[10])").strip_mapped_dimensions(), type("tensor<float>(x[10])")); + EXPECT_EQUAL(type("tensor<float>(a[1],b{},c[2],d{},e[3],f{})").strip_mapped_dimensions(), type("tensor<float>(a[1],c[2],e[3])")); +} + +TEST("require that indexed dimensions can be stripped") { + EXPECT_EQUAL(type("error").strip_indexed_dimensions(), type("error")); + EXPECT_EQUAL(type("double").strip_indexed_dimensions(), type("double")); + EXPECT_EQUAL(type("tensor<float>(x{})").strip_indexed_dimensions(), type("tensor<float>(x{})")); + EXPECT_EQUAL(type("tensor<float>(x[10])").strip_indexed_dimensions(), type("double")); + EXPECT_EQUAL(type("tensor<float>(a[1],b{},c[2],d{},e[3],f{})").strip_indexed_dimensions(), type("tensor<float>(b{},d{},f{})")); +} + +TEST("require that value types can be wrapped inside each other") { + EXPECT_EQUAL(type("error").wrap(type("error")), type("error")); + EXPECT_EQUAL(type("double").wrap(type("error")), type("error")); + EXPECT_EQUAL(type("error").wrap(type("double")), type("error")); + EXPECT_EQUAL(type("double").wrap(type("double")), type("double")); + EXPECT_EQUAL(type("tensor<int8>(x{})").wrap(type("tensor<int8>(y[10])")), type("tensor<int8>(x{},y[10])")); + EXPECT_EQUAL(type("tensor<int8>(a{},c{})").wrap(type("tensor<int8>(b[10],d[5])")), type("tensor<int8>(a{},b[10],c{},d[5])")); + EXPECT_EQUAL(type("tensor<int8>(x{})").wrap(type("tensor<int8>(x[10])")), type("error")); // dimension name conflict + EXPECT_EQUAL(type("tensor<int8>(x{},z[2])").wrap(type("tensor<int8>(y[10])")), type("error")); // outer cannot have indexed dimensions + EXPECT_EQUAL(type("tensor<int8>(x{})").wrap(type("tensor<int8>(y[10],z{})")), type("error")); // inner cannot have mapped dimensions + EXPECT_EQUAL(type("double").wrap(type("tensor<int8>(y[10])")), type("tensor<int8>(y[10])")); // NB: no decay + EXPECT_EQUAL(type("tensor<int8>(x{})").wrap(type("double")), type("tensor<float>(x{})")); // NB: decay +} + TEST("require that dimension index can be obtained") { EXPECT_EQUAL(type("error").dimension_index("x"), ValueType::Dimension::npos); EXPECT_EQUAL(type("double").dimension_index("x"), ValueType::Dimension::npos); diff --git a/eval/src/vespa/eval/eval/cell_type.h b/eval/src/vespa/eval/eval/cell_type.h index b1fa29a75a5..c15a5b68dba 100644 --- a/eval/src/vespa/eval/eval/cell_type.h +++ b/eval/src/vespa/eval/eval/cell_type.h @@ -129,6 +129,9 @@ struct CellMeta { // convenience functions to be used for specific operations constexpr CellMeta map() const { return decay(); } + constexpr CellMeta wrap(CellMeta inner) const { + return (inner.is_scalar) ? decay() : inner; + } constexpr CellMeta reduce(bool output_is_scalar) const { return normalize(cell_type, output_is_scalar).decay(); } diff --git a/eval/src/vespa/eval/eval/function.cpp b/eval/src/vespa/eval/eval/function.cpp index edcd241b6bf..1c4dcc3b5db 100644 --- a/eval/src/vespa/eval/eval/function.cpp +++ b/eval/src/vespa/eval/eval/function.cpp @@ -573,6 +573,13 @@ void parse_tensor_map(ParseContext &ctx) { ctx.push_expression(std::make_unique<nodes::TensorMap>(std::move(child), std::move(lambda))); } +void parse_tensor_map_subspaces(ParseContext &ctx) { + Node_UP child = get_expression(ctx); + ctx.eat(','); + auto lambda = parse_lambda(ctx, 1); + ctx.push_expression(std::make_unique<nodes::TensorMapSubspaces>(std::move(child), std::move(lambda))); +} + void parse_tensor_join(ParseContext &ctx) { Node_UP lhs = get_expression(ctx); ctx.eat(','); @@ -856,6 +863,8 @@ bool maybe_parse_call(ParseContext &ctx, const vespalib::string &name) { parse_call(ctx, std::move(call)); } else if (name == "map") { parse_tensor_map(ctx); + } else if (name == "map_subspaces") { + parse_tensor_map_subspaces(ctx); } else if (name == "join") { parse_tensor_join(ctx); } else if (name == "merge") { diff --git a/eval/src/vespa/eval/eval/key_gen.cpp b/eval/src/vespa/eval/eval/key_gen.cpp index 2df20ac0d63..6d45aeafc26 100644 --- a/eval/src/vespa/eval/eval/key_gen.cpp +++ b/eval/src/vespa/eval/eval/key_gen.cpp @@ -39,64 +39,65 @@ struct KeyGen : public NodeVisitor, public NodeTraverser { add_double(node.get_entry(i).get_const_double_value()); } } - void visit(const Neg &) override { add_byte( 5); } - void visit(const Not &) override { add_byte( 6); } - void visit(const If &node) override { add_byte( 7); add_double(node.p_true()); } - void visit(const Error &) override { add_byte( 9); } - void visit(const TensorMap &) override { add_byte(10); } // lambda should be part of key - void visit(const TensorJoin &) override { add_byte(11); } // lambda should be part of key - void visit(const TensorMerge &) override { add_byte(12); } // lambda should be part of key - void visit(const TensorReduce &) override { add_byte(13); } // aggr/dimensions should be part of key - void visit(const TensorRename &) override { add_byte(14); } // dimensions should be part of key - void visit(const TensorConcat &) override { add_byte(15); } // dimension should be part of key - void visit(const TensorCellCast &) override { add_byte(16); } // cell type should be part of key - void visit(const TensorCreate &) override { add_byte(17); } // type/addr should be part of key - void visit(const TensorLambda &) override { add_byte(18); } // type/lambda should be part of key - void visit(const TensorPeek &) override { add_byte(19); } // addr should be part of key - void visit(const Add &) override { add_byte(20); } - void visit(const Sub &) override { add_byte(21); } - void visit(const Mul &) override { add_byte(22); } - void visit(const Div &) override { add_byte(23); } - void visit(const Mod &) override { add_byte(24); } - void visit(const Pow &) override { add_byte(25); } - void visit(const Equal &) override { add_byte(26); } - void visit(const NotEqual &) override { add_byte(27); } - void visit(const Approx &) override { add_byte(28); } - void visit(const Less &) override { add_byte(29); } - void visit(const LessEqual &) override { add_byte(30); } - void visit(const Greater &) override { add_byte(31); } - void visit(const GreaterEqual &) override { add_byte(32); } - void visit(const And &) override { add_byte(34); } - void visit(const Or &) override { add_byte(35); } - void visit(const Cos &) override { add_byte(36); } - void visit(const Sin &) override { add_byte(37); } - void visit(const Tan &) override { add_byte(38); } - void visit(const Cosh &) override { add_byte(39); } - void visit(const Sinh &) override { add_byte(40); } - void visit(const Tanh &) override { add_byte(41); } - void visit(const Acos &) override { add_byte(42); } - void visit(const Asin &) override { add_byte(43); } - void visit(const Atan &) override { add_byte(44); } - void visit(const Exp &) override { add_byte(45); } - void visit(const Log10 &) override { add_byte(46); } - void visit(const Log &) override { add_byte(47); } - void visit(const Sqrt &) override { add_byte(48); } - void visit(const Ceil &) override { add_byte(49); } - void visit(const Fabs &) override { add_byte(50); } - void visit(const Floor &) override { add_byte(51); } - void visit(const Atan2 &) override { add_byte(52); } - void visit(const Ldexp &) override { add_byte(53); } - void visit(const Pow2 &) override { add_byte(54); } - void visit(const Fmod &) override { add_byte(55); } - void visit(const Min &) override { add_byte(56); } - void visit(const Max &) override { add_byte(57); } - void visit(const IsNan &) override { add_byte(58); } - void visit(const Relu &) override { add_byte(59); } - void visit(const Sigmoid &) override { add_byte(60); } - void visit(const Elu &) override { add_byte(61); } - void visit(const Erf &) override { add_byte(62); } - void visit(const Bit &) override { add_byte(63); } - void visit(const Hamming &) override { add_byte(64); } + void visit(const Neg &) override { add_byte( 5); } + void visit(const Not &) override { add_byte( 6); } + void visit(const If &node) override { add_byte( 7); add_double(node.p_true()); } + void visit(const Error &) override { add_byte( 8); } + void visit(const TensorMap &) override { add_byte( 9); } // lambda should be part of key + void visit(const TensorMapSubspaces &) override { add_byte(10); } // lambda should be part of key + void visit(const TensorJoin &) override { add_byte(11); } // lambda should be part of key + void visit(const TensorMerge &) override { add_byte(12); } // lambda should be part of key + void visit(const TensorReduce &) override { add_byte(13); } // aggr/dimensions should be part of key + void visit(const TensorRename &) override { add_byte(14); } // dimensions should be part of key + void visit(const TensorConcat &) override { add_byte(15); } // dimension should be part of key + void visit(const TensorCellCast &) override { add_byte(16); } // cell type should be part of key + void visit(const TensorCreate &) override { add_byte(17); } // type/addr should be part of key + void visit(const TensorLambda &) override { add_byte(18); } // type/lambda should be part of key + void visit(const TensorPeek &) override { add_byte(19); } // addr should be part of key + void visit(const Add &) override { add_byte(20); } + void visit(const Sub &) override { add_byte(21); } + void visit(const Mul &) override { add_byte(22); } + void visit(const Div &) override { add_byte(23); } + void visit(const Mod &) override { add_byte(24); } + void visit(const Pow &) override { add_byte(25); } + void visit(const Equal &) override { add_byte(26); } + void visit(const NotEqual &) override { add_byte(27); } + void visit(const Approx &) override { add_byte(28); } + void visit(const Less &) override { add_byte(29); } + void visit(const LessEqual &) override { add_byte(30); } + void visit(const Greater &) override { add_byte(31); } + void visit(const GreaterEqual &) override { add_byte(32); } + void visit(const And &) override { add_byte(34); } + void visit(const Or &) override { add_byte(35); } + void visit(const Cos &) override { add_byte(36); } + void visit(const Sin &) override { add_byte(37); } + void visit(const Tan &) override { add_byte(38); } + void visit(const Cosh &) override { add_byte(39); } + void visit(const Sinh &) override { add_byte(40); } + void visit(const Tanh &) override { add_byte(41); } + void visit(const Acos &) override { add_byte(42); } + void visit(const Asin &) override { add_byte(43); } + void visit(const Atan &) override { add_byte(44); } + void visit(const Exp &) override { add_byte(45); } + void visit(const Log10 &) override { add_byte(46); } + void visit(const Log &) override { add_byte(47); } + void visit(const Sqrt &) override { add_byte(48); } + void visit(const Ceil &) override { add_byte(49); } + void visit(const Fabs &) override { add_byte(50); } + void visit(const Floor &) override { add_byte(51); } + void visit(const Atan2 &) override { add_byte(52); } + void visit(const Ldexp &) override { add_byte(53); } + void visit(const Pow2 &) override { add_byte(54); } + void visit(const Fmod &) override { add_byte(55); } + void visit(const Min &) override { add_byte(56); } + void visit(const Max &) override { add_byte(57); } + void visit(const IsNan &) override { add_byte(58); } + void visit(const Relu &) override { add_byte(59); } + void visit(const Sigmoid &) override { add_byte(60); } + void visit(const Elu &) override { add_byte(61); } + void visit(const Erf &) override { add_byte(62); } + void visit(const Bit &) override { add_byte(63); } + void visit(const Hamming &) override { add_byte(64); } // traverse bool open(const Node &node) override { node.accept(*this); return true; } diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp index 5266aa64b8c..ca95d822be7 100644 --- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp +++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp @@ -456,6 +456,9 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser { void visit(const TensorMap &node) override { make_error(node.num_children()); } + void visit(const TensorMapSubspaces &node) override { + make_error(node.num_children()); + } void visit(const TensorJoin &node) override { make_error(node.num_children()); } diff --git a/eval/src/vespa/eval/eval/make_tensor_function.cpp b/eval/src/vespa/eval/eval/make_tensor_function.cpp index fe9c9704f6c..0b671a7725e 100644 --- a/eval/src/vespa/eval/eval/make_tensor_function.cpp +++ b/eval/src/vespa/eval/eval/make_tensor_function.cpp @@ -51,6 +51,12 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser { stack.back() = tensor_function::map(a, function, stash); } + void make_map_subspaces(const TensorMapSubspaces &node) { + assert(stack.size() >= 1); + const auto &a = stack.back().get(); + stack.back() = tensor_function::map_subspaces(a, node.lambda(), types.export_types(node.lambda().root()), stash); + } + void make_join(const Node &, operation::op2_t function) { assert(stack.size() >= 2); const auto &b = stack.back().get(); @@ -198,6 +204,9 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser { make_map(node, token.get()->get().get_function<1>()); } } + void visit(const TensorMapSubspaces &node) override { + make_map_subspaces(node); + } void visit(const TensorJoin &node) override { if (auto op2 = operation::lookup_op2(node.lambda())) { make_join(node, op2.value()); diff --git a/eval/src/vespa/eval/eval/node_tools.cpp b/eval/src/vespa/eval/eval/node_tools.cpp index 482111aba8d..20712b833ee 100644 --- a/eval/src/vespa/eval/eval/node_tools.cpp +++ b/eval/src/vespa/eval/eval/node_tools.cpp @@ -125,64 +125,65 @@ struct CopyNode : NodeTraverser, NodeVisitor { } // tensor nodes - void visit(const TensorMap &node) override { not_implemented(node); } - void visit(const TensorJoin &node) override { not_implemented(node); } - void visit(const TensorMerge &node) override { not_implemented(node); } - void visit(const TensorReduce &node) override { not_implemented(node); } - void visit(const TensorRename &node) override { not_implemented(node); } - void visit(const TensorConcat &node) override { not_implemented(node); } - void visit(const TensorCellCast &node) override { not_implemented(node); } - void visit(const TensorCreate &node) override { not_implemented(node); } - void visit(const TensorLambda &node) override { not_implemented(node); } - void visit(const TensorPeek &node) override { not_implemented(node); } + void visit(const TensorMap &node) override { not_implemented(node); } + void visit(const TensorMapSubspaces &node) override { not_implemented(node); } + void visit(const TensorJoin &node) override { not_implemented(node); } + void visit(const TensorMerge &node) override { not_implemented(node); } + void visit(const TensorReduce &node) override { not_implemented(node); } + void visit(const TensorRename &node) override { not_implemented(node); } + void visit(const TensorConcat &node) override { not_implemented(node); } + void visit(const TensorCellCast &node) override { not_implemented(node); } + void visit(const TensorCreate &node) override { not_implemented(node); } + void visit(const TensorLambda &node) override { not_implemented(node); } + void visit(const TensorPeek &node) override { not_implemented(node); } // operator nodes - void visit(const Add &node) override { copy_operator(node); } - void visit(const Sub &node) override { copy_operator(node); } - void visit(const Mul &node) override { copy_operator(node); } - void visit(const Div &node) override { copy_operator(node); } - void visit(const Mod &node) override { copy_operator(node); } - void visit(const Pow &node) override { copy_operator(node); } - void visit(const Equal &node) override { copy_operator(node); } - void visit(const NotEqual &node) override { copy_operator(node); } - void visit(const Approx &node) override { copy_operator(node); } - void visit(const Less &node) override { copy_operator(node); } - void visit(const LessEqual &node) override { copy_operator(node); } - void visit(const Greater &node) override { copy_operator(node); } - void visit(const GreaterEqual &node) override { copy_operator(node); } - void visit(const And &node) override { copy_operator(node); } - void visit(const Or &node) override { copy_operator(node); } + void visit(const Add &node) override { copy_operator(node); } + void visit(const Sub &node) override { copy_operator(node); } + void visit(const Mul &node) override { copy_operator(node); } + void visit(const Div &node) override { copy_operator(node); } + void visit(const Mod &node) override { copy_operator(node); } + void visit(const Pow &node) override { copy_operator(node); } + void visit(const Equal &node) override { copy_operator(node); } + void visit(const NotEqual &node) override { copy_operator(node); } + void visit(const Approx &node) override { copy_operator(node); } + void visit(const Less &node) override { copy_operator(node); } + void visit(const LessEqual &node) override { copy_operator(node); } + void visit(const Greater &node) override { copy_operator(node); } + void visit(const GreaterEqual &node) override { copy_operator(node); } + void visit(const And &node) override { copy_operator(node); } + void visit(const Or &node) override { copy_operator(node); } // call nodes - void visit(const Cos &node) override { copy_call(node); } - void visit(const Sin &node) override { copy_call(node); } - void visit(const Tan &node) override { copy_call(node); } - void visit(const Cosh &node) override { copy_call(node); } - void visit(const Sinh &node) override { copy_call(node); } - void visit(const Tanh &node) override { copy_call(node); } - void visit(const Acos &node) override { copy_call(node); } - void visit(const Asin &node) override { copy_call(node); } - void visit(const Atan &node) override { copy_call(node); } - void visit(const Exp &node) override { copy_call(node); } - void visit(const Log10 &node) override { copy_call(node); } - void visit(const Log &node) override { copy_call(node); } - void visit(const Sqrt &node) override { copy_call(node); } - void visit(const Ceil &node) override { copy_call(node); } - void visit(const Fabs &node) override { copy_call(node); } - void visit(const Floor &node) override { copy_call(node); } - void visit(const Atan2 &node) override { copy_call(node); } - void visit(const Ldexp &node) override { copy_call(node); } - void visit(const Pow2 &node) override { copy_call(node); } - void visit(const Fmod &node) override { copy_call(node); } - void visit(const Min &node) override { copy_call(node); } - void visit(const Max &node) override { copy_call(node); } - void visit(const IsNan &node) override { copy_call(node); } - void visit(const Relu &node) override { copy_call(node); } - void visit(const Sigmoid &node) override { copy_call(node); } - void visit(const Elu &node) override { copy_call(node); } - void visit(const Erf &node) override { copy_call(node); } - void visit(const Bit &node) override { copy_call(node); } - void visit(const Hamming &node) override { copy_call(node); } + void visit(const Cos &node) override { copy_call(node); } + void visit(const Sin &node) override { copy_call(node); } + void visit(const Tan &node) override { copy_call(node); } + void visit(const Cosh &node) override { copy_call(node); } + void visit(const Sinh &node) override { copy_call(node); } + void visit(const Tanh &node) override { copy_call(node); } + void visit(const Acos &node) override { copy_call(node); } + void visit(const Asin &node) override { copy_call(node); } + void visit(const Atan &node) override { copy_call(node); } + void visit(const Exp &node) override { copy_call(node); } + void visit(const Log10 &node) override { copy_call(node); } + void visit(const Log &node) override { copy_call(node); } + void visit(const Sqrt &node) override { copy_call(node); } + void visit(const Ceil &node) override { copy_call(node); } + void visit(const Fabs &node) override { copy_call(node); } + void visit(const Floor &node) override { copy_call(node); } + void visit(const Atan2 &node) override { copy_call(node); } + void visit(const Ldexp &node) override { copy_call(node); } + void visit(const Pow2 &node) override { copy_call(node); } + void visit(const Fmod &node) override { copy_call(node); } + void visit(const Min &node) override { copy_call(node); } + void visit(const Max &node) override { copy_call(node); } + void visit(const IsNan &node) override { copy_call(node); } + void visit(const Relu &node) override { copy_call(node); } + void visit(const Sigmoid &node) override { copy_call(node); } + void visit(const Elu &node) override { copy_call(node); } + void visit(const Erf &node) override { copy_call(node); } + void visit(const Bit &node) override { copy_call(node); } + void visit(const Hamming &node) override { copy_call(node); } // traverse nodes bool open(const Node &) override { return !error; } diff --git a/eval/src/vespa/eval/eval/node_types.cpp b/eval/src/vespa/eval/eval/node_types.cpp index c234631984f..767d0f8b28a 100644 --- a/eval/src/vespa/eval/eval/node_types.cpp +++ b/eval/src/vespa/eval/eval/node_types.cpp @@ -139,6 +139,29 @@ struct TypeResolver : public NodeVisitor, public NodeTraverser { bind(ValueType::error_type(), node, false); } void visit(const TensorMap &node) override { resolve_op1(node); } + void visit(const TensorMapSubspaces &node) override { + const ValueType &in_type = type(node.child()); + auto outer_type = in_type.strip_indexed_dimensions(); + auto inner_type = in_type.strip_mapped_dimensions(); + std::vector<ValueType> arg_type({inner_type}); + NodeTypes lambda_types(node.lambda(), arg_type); + const ValueType &lambda_res = lambda_types.get_type(node.lambda().root()); + if (lambda_res.is_error()) { + import_errors(lambda_types); + return fail(node, "lambda function has type errors", false); + } + if (lambda_res.count_mapped_dimensions() > 0) { + return fail(node, fmt("lambda function result contains mapped dimensions: %s", + lambda_res.to_spec().c_str()), false); + } + auto res_type = outer_type.wrap(lambda_res); + if (res_type.is_error()) { + return fail(node, fmt("lambda result contains dimensions that conflict with input type: %s <-> %s", + lambda_res.to_spec().c_str(), in_type.to_spec().c_str()), false); + } + import_types(lambda_types); + bind(res_type, node); + } void visit(const TensorJoin &node) override { resolve_op2(node); } void visit(const TensorMerge &node) override { bind(ValueType::merge(type(node.get_child(0)), @@ -316,6 +339,9 @@ struct TypeExporter : public NodeTraverser { if (auto lambda = as<TensorLambda>(node)) { lambda->lambda().root().traverse(*this); } + if (auto map_subspaces = as<TensorMapSubspaces>(node)) { + map_subspaces->lambda().root().traverse(*this); + } return true; } void close(const Node &node) override { diff --git a/eval/src/vespa/eval/eval/node_visitor.h b/eval/src/vespa/eval/eval/node_visitor.h index dcd1486824a..c57253d138c 100644 --- a/eval/src/vespa/eval/eval/node_visitor.h +++ b/eval/src/vespa/eval/eval/node_visitor.h @@ -18,74 +18,75 @@ namespace vespalib::eval { struct NodeVisitor { // basic nodes - virtual void visit(const nodes::Number &) = 0; - virtual void visit(const nodes::Symbol &) = 0; - virtual void visit(const nodes::String &) = 0; - virtual void visit(const nodes::In &) = 0; - virtual void visit(const nodes::Neg &) = 0; - virtual void visit(const nodes::Not &) = 0; - virtual void visit(const nodes::If &) = 0; - virtual void visit(const nodes::Error &) = 0; + virtual void visit(const nodes::Number &) = 0; + virtual void visit(const nodes::Symbol &) = 0; + virtual void visit(const nodes::String &) = 0; + virtual void visit(const nodes::In &) = 0; + virtual void visit(const nodes::Neg &) = 0; + virtual void visit(const nodes::Not &) = 0; + virtual void visit(const nodes::If &) = 0; + virtual void visit(const nodes::Error &) = 0; // tensor nodes - virtual void visit(const nodes::TensorMap &) = 0; - virtual void visit(const nodes::TensorJoin &) = 0; - virtual void visit(const nodes::TensorMerge &) = 0; - virtual void visit(const nodes::TensorReduce &) = 0; - virtual void visit(const nodes::TensorRename &) = 0; - virtual void visit(const nodes::TensorConcat &) = 0; - virtual void visit(const nodes::TensorCellCast &) = 0; - virtual void visit(const nodes::TensorCreate &) = 0; - virtual void visit(const nodes::TensorLambda &) = 0; - virtual void visit(const nodes::TensorPeek &) = 0; + virtual void visit(const nodes::TensorMap &) = 0; + virtual void visit(const nodes::TensorMapSubspaces &) = 0; + virtual void visit(const nodes::TensorJoin &) = 0; + virtual void visit(const nodes::TensorMerge &) = 0; + virtual void visit(const nodes::TensorReduce &) = 0; + virtual void visit(const nodes::TensorRename &) = 0; + virtual void visit(const nodes::TensorConcat &) = 0; + virtual void visit(const nodes::TensorCellCast &) = 0; + virtual void visit(const nodes::TensorCreate &) = 0; + virtual void visit(const nodes::TensorLambda &) = 0; + virtual void visit(const nodes::TensorPeek &) = 0; // operator nodes - virtual void visit(const nodes::Add &) = 0; - virtual void visit(const nodes::Sub &) = 0; - virtual void visit(const nodes::Mul &) = 0; - virtual void visit(const nodes::Div &) = 0; - virtual void visit(const nodes::Mod &) = 0; - virtual void visit(const nodes::Pow &) = 0; - virtual void visit(const nodes::Equal &) = 0; - virtual void visit(const nodes::NotEqual &) = 0; - virtual void visit(const nodes::Approx &) = 0; - virtual void visit(const nodes::Less &) = 0; - virtual void visit(const nodes::LessEqual &) = 0; - virtual void visit(const nodes::Greater &) = 0; - virtual void visit(const nodes::GreaterEqual &) = 0; - virtual void visit(const nodes::And &) = 0; - virtual void visit(const nodes::Or &) = 0; + virtual void visit(const nodes::Add &) = 0; + virtual void visit(const nodes::Sub &) = 0; + virtual void visit(const nodes::Mul &) = 0; + virtual void visit(const nodes::Div &) = 0; + virtual void visit(const nodes::Mod &) = 0; + virtual void visit(const nodes::Pow &) = 0; + virtual void visit(const nodes::Equal &) = 0; + virtual void visit(const nodes::NotEqual &) = 0; + virtual void visit(const nodes::Approx &) = 0; + virtual void visit(const nodes::Less &) = 0; + virtual void visit(const nodes::LessEqual &) = 0; + virtual void visit(const nodes::Greater &) = 0; + virtual void visit(const nodes::GreaterEqual &) = 0; + virtual void visit(const nodes::And &) = 0; + virtual void visit(const nodes::Or &) = 0; // call nodes - virtual void visit(const nodes::Cos &) = 0; - virtual void visit(const nodes::Sin &) = 0; - virtual void visit(const nodes::Tan &) = 0; - virtual void visit(const nodes::Cosh &) = 0; - virtual void visit(const nodes::Sinh &) = 0; - virtual void visit(const nodes::Tanh &) = 0; - virtual void visit(const nodes::Acos &) = 0; - virtual void visit(const nodes::Asin &) = 0; - virtual void visit(const nodes::Atan &) = 0; - virtual void visit(const nodes::Exp &) = 0; - virtual void visit(const nodes::Log10 &) = 0; - virtual void visit(const nodes::Log &) = 0; - virtual void visit(const nodes::Sqrt &) = 0; - virtual void visit(const nodes::Ceil &) = 0; - virtual void visit(const nodes::Fabs &) = 0; - virtual void visit(const nodes::Floor &) = 0; - virtual void visit(const nodes::Atan2 &) = 0; - virtual void visit(const nodes::Ldexp &) = 0; - virtual void visit(const nodes::Pow2 &) = 0; - virtual void visit(const nodes::Fmod &) = 0; - virtual void visit(const nodes::Min &) = 0; - virtual void visit(const nodes::Max &) = 0; - virtual void visit(const nodes::IsNan &) = 0; - virtual void visit(const nodes::Relu &) = 0; - virtual void visit(const nodes::Sigmoid &) = 0; - virtual void visit(const nodes::Elu &) = 0; - virtual void visit(const nodes::Erf &) = 0; - virtual void visit(const nodes::Bit &) = 0; - virtual void visit(const nodes::Hamming &) = 0; + virtual void visit(const nodes::Cos &) = 0; + virtual void visit(const nodes::Sin &) = 0; + virtual void visit(const nodes::Tan &) = 0; + virtual void visit(const nodes::Cosh &) = 0; + virtual void visit(const nodes::Sinh &) = 0; + virtual void visit(const nodes::Tanh &) = 0; + virtual void visit(const nodes::Acos &) = 0; + virtual void visit(const nodes::Asin &) = 0; + virtual void visit(const nodes::Atan &) = 0; + virtual void visit(const nodes::Exp &) = 0; + virtual void visit(const nodes::Log10 &) = 0; + virtual void visit(const nodes::Log &) = 0; + virtual void visit(const nodes::Sqrt &) = 0; + virtual void visit(const nodes::Ceil &) = 0; + virtual void visit(const nodes::Fabs &) = 0; + virtual void visit(const nodes::Floor &) = 0; + virtual void visit(const nodes::Atan2 &) = 0; + virtual void visit(const nodes::Ldexp &) = 0; + virtual void visit(const nodes::Pow2 &) = 0; + virtual void visit(const nodes::Fmod &) = 0; + virtual void visit(const nodes::Min &) = 0; + virtual void visit(const nodes::Max &) = 0; + virtual void visit(const nodes::IsNan &) = 0; + virtual void visit(const nodes::Relu &) = 0; + virtual void visit(const nodes::Sigmoid &) = 0; + virtual void visit(const nodes::Elu &) = 0; + virtual void visit(const nodes::Erf &) = 0; + virtual void visit(const nodes::Bit &) = 0; + virtual void visit(const nodes::Hamming &) = 0; virtual ~NodeVisitor() {} }; @@ -95,68 +96,69 @@ struct NodeVisitor { * of all types not specifically handled. **/ struct EmptyNodeVisitor : NodeVisitor { - void visit(const nodes::Number &) override {} - void visit(const nodes::Symbol &) override {} - void visit(const nodes::String &) override {} - void visit(const nodes::In &) override {} - void visit(const nodes::Neg &) override {} - void visit(const nodes::Not &) override {} - void visit(const nodes::If &) override {} - void visit(const nodes::Error &) override {} - void visit(const nodes::TensorMap &) override {} - void visit(const nodes::TensorJoin &) override {} - void visit(const nodes::TensorMerge &) override {} - void visit(const nodes::TensorReduce &) override {} - void visit(const nodes::TensorRename &) override {} - void visit(const nodes::TensorConcat &) override {} - void visit(const nodes::TensorCellCast &) override {} - void visit(const nodes::TensorCreate &) override {} - void visit(const nodes::TensorLambda &) override {} - void visit(const nodes::TensorPeek &) override {} - void visit(const nodes::Add &) override {} - void visit(const nodes::Sub &) override {} - void visit(const nodes::Mul &) override {} - void visit(const nodes::Div &) override {} - void visit(const nodes::Mod &) override {} - void visit(const nodes::Pow &) override {} - void visit(const nodes::Equal &) override {} - void visit(const nodes::NotEqual &) override {} - void visit(const nodes::Approx &) override {} - void visit(const nodes::Less &) override {} - void visit(const nodes::LessEqual &) override {} - void visit(const nodes::Greater &) override {} - void visit(const nodes::GreaterEqual &) override {} - void visit(const nodes::And &) override {} - void visit(const nodes::Or &) override {} - void visit(const nodes::Cos &) override {} - void visit(const nodes::Sin &) override {} - void visit(const nodes::Tan &) override {} - void visit(const nodes::Cosh &) override {} - void visit(const nodes::Sinh &) override {} - void visit(const nodes::Tanh &) override {} - void visit(const nodes::Acos &) override {} - void visit(const nodes::Asin &) override {} - void visit(const nodes::Atan &) override {} - void visit(const nodes::Exp &) override {} - void visit(const nodes::Log10 &) override {} - void visit(const nodes::Log &) override {} - void visit(const nodes::Sqrt &) override {} - void visit(const nodes::Ceil &) override {} - void visit(const nodes::Fabs &) override {} - void visit(const nodes::Floor &) override {} - void visit(const nodes::Atan2 &) override {} - void visit(const nodes::Ldexp &) override {} - void visit(const nodes::Pow2 &) override {} - void visit(const nodes::Fmod &) override {} - void visit(const nodes::Min &) override {} - void visit(const nodes::Max &) override {} - void visit(const nodes::IsNan &) override {} - void visit(const nodes::Relu &) override {} - void visit(const nodes::Sigmoid &) override {} - void visit(const nodes::Elu &) override {} - void visit(const nodes::Erf &) override {} - void visit(const nodes::Bit &) override {} - void visit(const nodes::Hamming &) override {} + void visit(const nodes::Number &) override {} + void visit(const nodes::Symbol &) override {} + void visit(const nodes::String &) override {} + void visit(const nodes::In &) override {} + void visit(const nodes::Neg &) override {} + void visit(const nodes::Not &) override {} + void visit(const nodes::If &) override {} + void visit(const nodes::Error &) override {} + void visit(const nodes::TensorMap &) override {} + void visit(const nodes::TensorMapSubspaces &) override {} + void visit(const nodes::TensorJoin &) override {} + void visit(const nodes::TensorMerge &) override {} + void visit(const nodes::TensorReduce &) override {} + void visit(const nodes::TensorRename &) override {} + void visit(const nodes::TensorConcat &) override {} + void visit(const nodes::TensorCellCast &) override {} + void visit(const nodes::TensorCreate &) override {} + void visit(const nodes::TensorLambda &) override {} + void visit(const nodes::TensorPeek &) override {} + void visit(const nodes::Add &) override {} + void visit(const nodes::Sub &) override {} + void visit(const nodes::Mul &) override {} + void visit(const nodes::Div &) override {} + void visit(const nodes::Mod &) override {} + void visit(const nodes::Pow &) override {} + void visit(const nodes::Equal &) override {} + void visit(const nodes::NotEqual &) override {} + void visit(const nodes::Approx &) override {} + void visit(const nodes::Less &) override {} + void visit(const nodes::LessEqual &) override {} + void visit(const nodes::Greater &) override {} + void visit(const nodes::GreaterEqual &) override {} + void visit(const nodes::And &) override {} + void visit(const nodes::Or &) override {} + void visit(const nodes::Cos &) override {} + void visit(const nodes::Sin &) override {} + void visit(const nodes::Tan &) override {} + void visit(const nodes::Cosh &) override {} + void visit(const nodes::Sinh &) override {} + void visit(const nodes::Tanh &) override {} + void visit(const nodes::Acos &) override {} + void visit(const nodes::Asin &) override {} + void visit(const nodes::Atan &) override {} + void visit(const nodes::Exp &) override {} + void visit(const nodes::Log10 &) override {} + void visit(const nodes::Log &) override {} + void visit(const nodes::Sqrt &) override {} + void visit(const nodes::Ceil &) override {} + void visit(const nodes::Fabs &) override {} + void visit(const nodes::Floor &) override {} + void visit(const nodes::Atan2 &) override {} + void visit(const nodes::Ldexp &) override {} + void visit(const nodes::Pow2 &) override {} + void visit(const nodes::Fmod &) override {} + void visit(const nodes::Min &) override {} + void visit(const nodes::Max &) override {} + void visit(const nodes::IsNan &) override {} + void visit(const nodes::Relu &) override {} + void visit(const nodes::Sigmoid &) override {} + void visit(const nodes::Elu &) override {} + void visit(const nodes::Erf &) override {} + void visit(const nodes::Bit &) override {} + void visit(const nodes::Hamming &) override {} }; } diff --git a/eval/src/vespa/eval/eval/tensor_function.cpp b/eval/src/vespa/eval/eval/tensor_function.cpp index b258b6c824e..14d486aeb48 100644 --- a/eval/src/vespa/eval/eval/tensor_function.cpp +++ b/eval/src/vespa/eval/eval/tensor_function.cpp @@ -12,6 +12,7 @@ #include <vespa/eval/instruction/generic_join.h> #include <vespa/eval/instruction/generic_lambda.h> #include <vespa/eval/instruction/generic_map.h> +#include <vespa/eval/instruction/generic_map_subspaces.h> #include <vespa/eval/instruction/generic_merge.h> #include <vespa/eval/instruction/generic_peek.h> #include <vespa/eval/instruction/generic_reduce.h> @@ -172,6 +173,20 @@ Map::visit_self(vespalib::ObjectVisitor &visitor) const //----------------------------------------------------------------------------- +InterpretedFunction::Instruction +MapSubspaces::compile_self(const ValueBuilderFactory &factory, Stash &stash) const +{ + return instruction::GenericMapSubspaces::make_instruction(*this, factory, stash); +} + +void +MapSubspaces::visit_self(vespalib::ObjectVisitor &visitor) const +{ + Super::visit_self(visitor); +} + +//----------------------------------------------------------------------------- + Instruction Join::compile_self(const ValueBuilderFactory &factory, Stash &stash) const { @@ -455,6 +470,11 @@ const TensorFunction &map(const TensorFunction &child, map_fun_t function, Stash return stash.create<Map>(result_type, child, function); } +const TensorFunction &map_subspaces(const TensorFunction &child, const Function &function, NodeTypes node_types, Stash &stash) { + auto result_type = child.result_type().strip_indexed_dimensions().wrap(node_types.get_type(function.root())); + return stash.create<MapSubspaces>(result_type, child, function, std::move(node_types)); +} + const TensorFunction &join(const TensorFunction &lhs, const TensorFunction &rhs, join_fun_t function, Stash &stash) { ValueType result_type = ValueType::join(lhs.result_type(), rhs.result_type()); return stash.create<Join>(result_type, lhs, rhs, function); diff --git a/eval/src/vespa/eval/eval/tensor_function.h b/eval/src/vespa/eval/eval/tensor_function.h index 24548bfae4d..2c703fbdfef 100644 --- a/eval/src/vespa/eval/eval/tensor_function.h +++ b/eval/src/vespa/eval/eval/tensor_function.h @@ -249,6 +249,29 @@ public: //----------------------------------------------------------------------------- +class MapSubspaces : public Op1 +{ + using Super = Op1; +private: + ValueType _inner_type; + std::shared_ptr<Function const> _lambda; + NodeTypes _lambda_types; +public: + MapSubspaces(const ValueType &result_type_in, const TensorFunction &child_in, const Function &lambda_in, NodeTypes lambda_types_in) + : Super(result_type_in, child_in), + _inner_type(child_in.result_type().strip_mapped_dimensions()), + _lambda(lambda_in.shared_from_this()), + _lambda_types(std::move(lambda_types_in)) {} + const ValueType &inner_type() const { return _inner_type; } + const Function &lambda() const { return *_lambda; } + const NodeTypes &types() const { return _lambda_types; } + bool result_is_mutable() const override { return true; } + InterpretedFunction::Instruction compile_self(const ValueBuilderFactory &factory, Stash &stash) const final override; + void visit_self(vespalib::ObjectVisitor &visitor) const override; +}; + +//----------------------------------------------------------------------------- + class Join : public Op2 { using Super = Op2; @@ -463,6 +486,7 @@ const TensorFunction &const_value(const Value &value, Stash &stash); const TensorFunction &inject(const ValueType &type, size_t param_idx, Stash &stash); const TensorFunction &reduce(const TensorFunction &child, Aggr aggr, const std::vector<vespalib::string> &dimensions, Stash &stash); const TensorFunction &map(const TensorFunction &child, map_fun_t function, Stash &stash); +const TensorFunction &map_subspaces(const TensorFunction &child, const Function &function, NodeTypes node_types, Stash &stash); const TensorFunction &join(const TensorFunction &lhs, const TensorFunction &rhs, join_fun_t function, Stash &stash); const TensorFunction &merge(const TensorFunction &lhs, const TensorFunction &rhs, join_fun_t function, Stash &stash); const TensorFunction &concat(const TensorFunction &lhs, const TensorFunction &rhs, const vespalib::string &dimension, Stash &stash); diff --git a/eval/src/vespa/eval/eval/tensor_nodes.cpp b/eval/src/vespa/eval/eval/tensor_nodes.cpp index bfcd1f979e2..ef2718234b2 100644 --- a/eval/src/vespa/eval/eval/tensor_nodes.cpp +++ b/eval/src/vespa/eval/eval/tensor_nodes.cpp @@ -5,15 +5,16 @@ namespace vespalib::eval::nodes { -void TensorMap ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } -void TensorJoin ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } -void TensorMerge ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } -void TensorReduce ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } -void TensorRename ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } -void TensorConcat ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } -void TensorCellCast::accept(NodeVisitor &visitor) const { visitor.visit(*this); } -void TensorCreate ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } -void TensorLambda ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } -void TensorPeek ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } +void TensorMap ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } +void TensorMapSubspaces::accept(NodeVisitor &visitor) const { visitor.visit(*this); } +void TensorJoin ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } +void TensorMerge ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } +void TensorReduce ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } +void TensorRename ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } +void TensorConcat ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } +void TensorCellCast ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } +void TensorCreate ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } +void TensorLambda ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } +void TensorPeek ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } } diff --git a/eval/src/vespa/eval/eval/tensor_nodes.h b/eval/src/vespa/eval/eval/tensor_nodes.h index 33f9cc5e39c..6ed19e81712 100644 --- a/eval/src/vespa/eval/eval/tensor_nodes.h +++ b/eval/src/vespa/eval/eval/tensor_nodes.h @@ -44,6 +44,36 @@ public: } }; +class TensorMapSubspaces : public Node { +private: + Node_UP _child; + std::shared_ptr<Function const> _lambda; +public: + TensorMapSubspaces(Node_UP child, std::shared_ptr<Function const> lambda) + : _child(std::move(child)), _lambda(std::move(lambda)) {} + const Node &child() const { return *_child; } + const Function &lambda() const { return *_lambda; } + vespalib::string dump(DumpContext &ctx) const override { + vespalib::string str; + str += "map_subspaces("; + str += _child->dump(ctx); + str += ","; + str += _lambda->dump_as_lambda(); + str += ")"; + return str; + } + void accept(NodeVisitor &visitor) const override; + size_t num_children() const override { return 1; } + const Node &get_child(size_t idx) const override { + (void) idx; + assert(idx == 0); + return *_child; + } + void detach_children(NodeHandler &handler) override { + handler.handle(std::move(_child)); + } +}; + class TensorJoin : public Node { private: Node_UP _lhs; diff --git a/eval/src/vespa/eval/eval/tensor_spec.cpp b/eval/src/vespa/eval/eval/tensor_spec.cpp index c9401606600..323f9eaf0fe 100644 --- a/eval/src/vespa/eval/eval/tensor_spec.cpp +++ b/eval/src/vespa/eval/eval/tensor_spec.cpp @@ -7,6 +7,7 @@ #include "value.h" #include "value_codec.h" #include "value_type.h" +#include <vespa/vespalib/util/require.h> #include <vespa/vespalib/util/overload.h> #include <vespa/vespalib/util/visit_ranges.h> #include <vespa/vespalib/util/stringfmt.h> @@ -182,19 +183,19 @@ struct NormalizeTensorSpec { size_t dense_key = 0; auto binding = entry.first.begin(); for (const auto &dim : type.dimensions()) { - assert(binding != entry.first.end()); - assert(dim.name == binding->first); - assert(dim.is_mapped() == binding->second.is_mapped()); + REQUIRE(binding != entry.first.end()); + REQUIRE(dim.name == binding->first); + REQUIRE(dim.is_mapped() == binding->second.is_mapped()); if (dim.is_mapped()) { sparse_key.push_back(binding->second.name); } else { - assert(binding->second.index < dim.size); + REQUIRE(binding->second.index < dim.size); dense_key = (dense_key * dim.size) + binding->second.index; } ++binding; } - assert(binding == entry.first.end()); - assert(dense_key < map.values_per_entry()); + REQUIRE(binding == entry.first.end()); + REQUIRE(dense_key < map.values_per_entry()); auto [tag, ignore] = map.lookup_or_add_entry(ConstArrayRef<vespalib::stringref>(sparse_key)); map.get_values(tag)[dense_key] = entry.second; } @@ -212,7 +213,7 @@ struct NormalizeTensorSpec { address.emplace(dim.name, *sparse_addr_iter++); } } - assert(sparse_addr_iter == keys.end()); + REQUIRE(sparse_addr_iter == keys.end()); for (size_t i = 0; i < values.size(); ++i) { size_t dense_key = i; for (auto dim = type.dimensions().rbegin(); @@ -364,7 +365,12 @@ TensorSpec::normalize() const if (my_type.is_error()) { return TensorSpec(my_type.to_spec()); } - return typify_invoke<1,TypifyCellType,NormalizeTensorSpec>(my_type.cell_type(), my_type, *this); + try { + return typify_invoke<1,TypifyCellType,NormalizeTensorSpec>(my_type.cell_type(), my_type, *this); + } catch (RequireFailedException &e) { + fprintf(stderr, "TensorSpec::normalize: invalid spec: %s\n", to_string().c_str()); + assert(false); // preserve crashing behavior + } } vespalib::string diff --git a/eval/src/vespa/eval/eval/test/reference_evaluation.cpp b/eval/src/vespa/eval/eval/test/reference_evaluation.cpp index 9bfa314493a..5a1fd2041dd 100644 --- a/eval/src/vespa/eval/eval/test/reference_evaluation.cpp +++ b/eval/src/vespa/eval/eval/test/reference_evaluation.cpp @@ -136,6 +136,13 @@ struct EvalNode : public NodeVisitor { result = ReferenceOperations::peek(spec, children); } + void eval_map_subspaces(const Node &node, const Node &lambda) { + auto fun = [&](const TensorSpec &subspace) { + return eval_node(lambda, {subspace}); + }; + result = ReferenceOperations::map_subspaces(eval_node(node, params), fun); + } + //------------------------------------------------------------------------- void visit(const Number &node) override { @@ -176,6 +183,9 @@ struct EvalNode : public NodeVisitor { }; eval_map(node.child(), my_op1); } + void visit(const TensorMapSubspaces &node) override { + eval_map_subspaces(node.child(), node.lambda().root()); + } void visit(const TensorJoin &node) override { auto my_op2 = [&](double a, double b) { return ReferenceEvaluation::eval(node.lambda(), {num(a), num(b)}).as_double(); diff --git a/eval/src/vespa/eval/eval/test/reference_operations.cpp b/eval/src/vespa/eval/eval/test/reference_operations.cpp index 5d79f168aaa..6771eda91e3 100644 --- a/eval/src/vespa/eval/eval/test/reference_operations.cpp +++ b/eval/src/vespa/eval/eval/test/reference_operations.cpp @@ -176,6 +176,52 @@ TensorSpec ReferenceOperations::map(const TensorSpec &in_a, map_fun_t func) { } +TensorSpec ReferenceOperations::map_subspaces(const TensorSpec &a, map_subspace_fun_t fun) { + auto type = ValueType::from_spec(a.type()); + auto outer_type = type.strip_indexed_dimensions(); + auto inner_type = type.strip_mapped_dimensions(); + auto inner_type_str = inner_type.to_spec(); + auto lambda_res_type = ValueType::from_spec(fun(TensorSpec(inner_type_str).normalize()).type()); + auto res_type = outer_type.wrap(lambda_res_type); + auto split = [](const auto &addr) { + TensorSpec::Address outer; + TensorSpec::Address inner; + for (const auto &[name, label]: addr) { + if (label.is_mapped()) { + outer.insert_or_assign(name, label); + } else { + inner.insert_or_assign(name, label); + } + } + return std::make_pair(outer, inner); + }; + auto combine = [](const auto &outer, const auto &inner) { + TensorSpec::Address addr; + for (const auto &[name, label]: outer) { + addr.insert_or_assign(name, label); + } + for (const auto &[name, label]: inner) { + addr.insert_or_assign(name, label); + } + return addr; + }; + std::map<TensorSpec::Address,TensorSpec> subspaces; + for (const auto &[addr, value]: a.cells()) { + auto [outer, inner] = split(addr); + auto &subspace = subspaces.try_emplace(outer, inner_type_str).first->second; + subspace.add(inner, value); + } + TensorSpec result(res_type.to_spec()); + for (const auto &[outer, subspace]: subspaces) { + auto mapped = fun(subspace); + for (const auto &[inner, value]: mapped.cells()) { + result.add(combine(outer, inner), value); + } + } + return result.normalize(); +} + + TensorSpec ReferenceOperations::merge(const TensorSpec &in_a, const TensorSpec &in_b, join_fun_t fun) { auto a = in_a.normalize(); auto b = in_b.normalize(); diff --git a/eval/src/vespa/eval/eval/test/reference_operations.h b/eval/src/vespa/eval/eval/test/reference_operations.h index 85aa73ec958..dd9c4b143ed 100644 --- a/eval/src/vespa/eval/eval/test/reference_operations.h +++ b/eval/src/vespa/eval/eval/test/reference_operations.h @@ -19,6 +19,7 @@ struct ReferenceOperations { using map_fun_t = std::function<double(double)>; using join_fun_t = std::function<double(double,double)>; using lambda_fun_t = std::function<double(const std::vector<size_t> &dimension_indexes)>; + using map_subspace_fun_t = std::function<TensorSpec(const TensorSpec &subspace)>; // mapping from cell address to index of child that computes the cell value using CreateSpec = tensor_function::Create::Spec; @@ -33,6 +34,7 @@ struct ReferenceOperations { static TensorSpec create(const vespalib::string &type, const CreateSpec &spec, const std::vector<TensorSpec> &children); static TensorSpec join(const TensorSpec &a, const TensorSpec &b, join_fun_t function); static TensorSpec map(const TensorSpec &a, map_fun_t func); + static TensorSpec map_subspaces(const TensorSpec &a, map_subspace_fun_t fun); static TensorSpec merge(const TensorSpec &a, const TensorSpec &b, join_fun_t fun); static TensorSpec peek(const PeekSpec &spec, const std::vector<TensorSpec> &children); static TensorSpec reduce(const TensorSpec &a, Aggr aggr, const std::vector<vespalib::string> &dims); diff --git a/eval/src/vespa/eval/eval/value_type.cpp b/eval/src/vespa/eval/eval/value_type.cpp index 1a83de9b0f9..fe70622de4e 100644 --- a/eval/src/vespa/eval/eval/value_type.cpp +++ b/eval/src/vespa/eval/eval/value_type.cpp @@ -138,6 +138,25 @@ struct Renamer { bool matched_all() const { return (match_cnt == from.size()); } }; +auto filter(const std::vector<Dimension> &dims, auto keep) { + std::vector<Dimension> result; + result.reserve(dims.size()); + for (const auto &dim: dims) { + if (keep(dim)) { + result.push_back(dim); + } + } + return result; +} + +auto strip(CellType old_cell_type, const std::vector<Dimension> &old_dims, auto discard) { + auto new_dims = filter(old_dims, [discard](const auto &dim){ return !discard(dim); }); + if (new_dims.empty()) { + return ValueType::double_type(); + } + return ValueType::make_type(old_cell_type, std::move(new_dims)); +} + } // namespace vespalib::eval::<unnamed> constexpr ValueType::Dimension::size_type ValueType::Dimension::npos; @@ -245,37 +264,19 @@ ValueType::dense_subspace_size() const std::vector<ValueType::Dimension> ValueType::nontrivial_indexed_dimensions() const { - std::vector<ValueType::Dimension> result; - for (const auto &dim: dimensions()) { - if (dim.is_indexed() && !dim.is_trivial()) { - result.push_back(dim); - } - } - return result; + return filter(_dimensions, [](const auto &dim){ return !dim.is_trivial() && dim.is_indexed(); }); } std::vector<ValueType::Dimension> ValueType::indexed_dimensions() const { - std::vector<ValueType::Dimension> result; - for (const auto &dim: dimensions()) { - if (dim.is_indexed()) { - result.push_back(dim); - } - } - return result; + return filter(_dimensions, [](const auto &dim){ return dim.is_indexed(); }); } std::vector<ValueType::Dimension> ValueType::mapped_dimensions() const { - std::vector<ValueType::Dimension> result; - for (const auto &dim: dimensions()) { - if (dim.is_mapped()) { - result.push_back(dim); - } - } - return result; + return filter(_dimensions, [](const auto &dim){ return dim.is_mapped(); }); } size_t @@ -312,6 +313,31 @@ ValueType::dimension_names() const } ValueType +ValueType::strip_mapped_dimensions() const +{ + return error_if(_error, strip(_cell_type, _dimensions, + [](const auto &dim){ return dim.is_mapped(); })); +} + +ValueType +ValueType::strip_indexed_dimensions() const +{ + return error_if(_error, strip(_cell_type, _dimensions, + [](const auto &dim){ return dim.is_indexed(); })); +} + +ValueType +ValueType::wrap(const ValueType &inner) +{ + MyJoin result(_dimensions, inner._dimensions); + auto meta = cell_meta().wrap(inner.cell_meta()); + return error_if(_error || inner._error || result.mismatch || + (count_indexed_dimensions() > 0) || + (inner.count_mapped_dimensions() > 0), + make_type(meta.cell_type, std::move(result.dimensions))); +} + +ValueType ValueType::map() const { auto meta = cell_meta().map(); diff --git a/eval/src/vespa/eval/eval/value_type.h b/eval/src/vespa/eval/eval/value_type.h index b35e23ee4e6..a65dde398c2 100644 --- a/eval/src/vespa/eval/eval/value_type.h +++ b/eval/src/vespa/eval/eval/value_type.h @@ -81,6 +81,9 @@ public: } bool operator!=(const ValueType &rhs) const noexcept { return !(*this == rhs); } + ValueType strip_mapped_dimensions() const; + ValueType strip_indexed_dimensions() const; + ValueType wrap(const ValueType &inner); ValueType map() const; ValueType reduce(const std::vector<vespalib::string> &dimensions_in) const; ValueType peek(const std::vector<vespalib::string> &dimensions_in) const; diff --git a/eval/src/vespa/eval/instruction/CMakeLists.txt b/eval/src/vespa/eval/instruction/CMakeLists.txt index 22fa58a08fc..67203b6c7c8 100644 --- a/eval/src/vespa/eval/instruction/CMakeLists.txt +++ b/eval/src/vespa/eval/instruction/CMakeLists.txt @@ -24,6 +24,7 @@ vespa_add_library(eval_instruction OBJECT generic_join.cpp generic_lambda.cpp generic_map.cpp + generic_map_subspaces.cpp generic_merge.cpp generic_peek.cpp generic_reduce.cpp diff --git a/eval/src/vespa/eval/instruction/generic_map_subspaces.cpp b/eval/src/vespa/eval/instruction/generic_map_subspaces.cpp new file mode 100644 index 00000000000..1238d4f4e57 --- /dev/null +++ b/eval/src/vespa/eval/instruction/generic_map_subspaces.cpp @@ -0,0 +1,118 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "generic_map_subspaces.h" + +using namespace vespalib::eval::tensor_function; + +namespace vespalib::eval::instruction { + +using Instruction = InterpretedFunction::Instruction; +using State = InterpretedFunction::State; + +namespace { + +//----------------------------------------------------------------------------- + +struct InterpretedParams { + const ValueType &result_type; + const ValueType &inner_type; + InterpretedFunction fun; + size_t in_size; + size_t out_size; + bool direct_in; + bool direct_out; + InterpretedParams(const MapSubspaces &map_subspaces, const ValueBuilderFactory &factory) + : result_type(map_subspaces.result_type()), + inner_type(map_subspaces.inner_type()), + fun(factory, map_subspaces.lambda().root(), map_subspaces.types()), + in_size(inner_type.dense_subspace_size()), + out_size(result_type.dense_subspace_size()), + direct_in(map_subspaces.child().result_type().cell_type() == inner_type.cell_type()), + direct_out(map_subspaces.types().get_type(map_subspaces.lambda().root()).cell_type() == result_type.cell_type()) + { + assert(direct_in || (in_size == 1)); + assert(direct_out || (out_size == 1)); + } +}; + +struct ParamView final : Value, LazyParams { + const ValueType &my_type; + TypedCells my_cells; + double value; + bool direct; +public: + ParamView(const ValueType &type_in, bool direct_in) + : my_type(type_in), my_cells(), value(0.0), direct(direct_in) {} + const ValueType &type() const final override { return my_type; } + template <typename ICT> + void adjust(const ICT *cells, size_t size) { + if (direct) { + my_cells = TypedCells(cells, get_cell_type<ICT>(), size); + } else { + value = cells[0]; + my_cells = TypedCells(&value, CellType::DOUBLE, 1); + } + } + TypedCells cells() const final override { return my_cells; } + const Index &index() const final override { return TrivialIndex::get(); } + MemoryUsage get_memory_usage() const final override { return self_memory_usage<ParamView>(); } + const Value &resolve(size_t, Stash &) const final override { return *this; } +}; + +template <typename OCT> +struct ResultFiller { + OCT *dst; + bool direct; +public: + ResultFiller(OCT *dst_in, bool direct_out) + : dst(dst_in), direct(direct_out) {} + void fill(const Value &value) { + if (direct) { + auto cells = value.cells(); + memcpy(dst, cells.data, sizeof(OCT) * cells.size); + dst += cells.size; + } else { + *dst++ = value.as_double(); + } + } +}; + +template <typename ICT, typename OCT> +void my_generic_map_subspaces_op(InterpretedFunction::State &state, uint64_t param) { + const InterpretedParams ¶ms = unwrap_param<InterpretedParams>(param); + InterpretedFunction::Context ctx(params.fun); + const Value &input = state.peek(0); + const ICT *src = input.cells().typify<ICT>().data(); + size_t num_subspaces = input.index().size(); + auto res_cells = state.stash.create_uninitialized_array<OCT>(num_subspaces * params.out_size); + ResultFiller result_filler(res_cells.data(), params.direct_out); + ParamView param_view(params.inner_type, params.direct_in); + for (size_t i = 0; i < num_subspaces; ++i) { + param_view.adjust(src, params.in_size); + src += params.in_size; + result_filler.fill(params.fun.eval(ctx, param_view)); + } + state.pop_push(state.stash.create<ValueView>(params.result_type, input.index(), TypedCells(res_cells))); +} + +struct SelectGenericMapSubspacesOp { + template <typename ICT, typename OCT> static auto invoke() { + return my_generic_map_subspaces_op<ICT,OCT>; + } +}; + +//----------------------------------------------------------------------------- + +} // namespace <unnamed> + +Instruction +GenericMapSubspaces::make_instruction(const tensor_function::MapSubspaces &map_subspaces_in, + const ValueBuilderFactory &factory, Stash &stash) +{ + InterpretedParams ¶ms = stash.create<InterpretedParams>(map_subspaces_in, factory); + auto op = typify_invoke<2,TypifyCellType,SelectGenericMapSubspacesOp>(map_subspaces_in.child().result_type().cell_type(), + params.result_type.cell_type()); + return Instruction(op, wrap_param<InterpretedParams>(params)); +} + +} // namespace diff --git a/eval/src/vespa/eval/instruction/generic_map_subspaces.h b/eval/src/vespa/eval/instruction/generic_map_subspaces.h new file mode 100644 index 00000000000..f95ded60a1b --- /dev/null +++ b/eval/src/vespa/eval/instruction/generic_map_subspaces.h @@ -0,0 +1,17 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include <vespa/eval/eval/interpreted_function.h> +#include <vespa/eval/eval/tensor_function.h> +#include <vespa/eval/eval/value_type.h> + +namespace vespalib::eval::instruction { + +struct GenericMapSubspaces { + static InterpretedFunction::Instruction + make_instruction(const tensor_function::MapSubspaces &map_subspaces_in, + const ValueBuilderFactory &factory, Stash &stash); +}; + +} // namespace |