diff options
author | Håvard Pettersen <havardpe@oath.com> | 2020-03-09 16:10:22 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2020-03-11 14:00:42 +0000 |
commit | dcd0747caf70c53abca8dd1d6c7def42428c1f5c (patch) | |
tree | 0d8d59bf2cd50893e795abdb76ea94fbd14402ed | |
parent | b299b794b52afeacebcd875775f76e9ebf4f0dd7 (diff) |
handle tensor lambda as nested function with bindings
20 files changed, 361 insertions, 96 deletions
diff --git a/eval/CMakeLists.txt b/eval/CMakeLists.txt index febb254c53e..8eb198a9a0b 100644 --- a/eval/CMakeLists.txt +++ b/eval/CMakeLists.txt @@ -22,6 +22,7 @@ vespa_define_module( src/tests/eval/param_usage src/tests/eval/simple_tensor src/tests/eval/tensor_function + src/tests/eval/tensor_lambda src/tests/eval/tensor_spec src/tests/eval/value_cache src/tests/eval/value_type diff --git a/eval/src/tests/eval/function/function_test.cpp b/eval/src/tests/eval/function/function_test.cpp index 5316217d549..11fd8abf549 100644 --- a/eval/src/tests/eval/function/function_test.cpp +++ b/eval/src/tests/eval/function/function_test.cpp @@ -813,8 +813,8 @@ TEST("require that tensor rename dimension lists must have equal size") { //----------------------------------------------------------------------------- TEST("require that tensor lambda can be parsed") { - EXPECT_EQUAL("tensor(x[3]):{{x:0}:0,{x:1}:1,{x:2}:2}", Function::parse({}, "tensor(x[3])(x)")->dump()); - EXPECT_EQUAL("tensor(x[2],y[2]):{{x:0,y:0}:(0==0),{x:0,y:1}:(0==1),{x:1,y:0}:(1==0),{x:1,y:1}:(1==1)}", + EXPECT_EQUAL("tensor(x[3])(x)", Function::parse({}, "tensor(x[3])(x)")->dump()); + EXPECT_EQUAL("tensor(x[2],y[2])(x==y)", Function::parse({}, " tensor ( x [ 2 ] , y [ 2 ] ) ( x == y ) ")->dump()); } @@ -825,7 +825,7 @@ TEST("require that tensor lambda requires appropriate tensor type") { } TEST("require that tensor lambda can use non-dimension symbols") { - EXPECT_EQUAL("tensor(x[2]):{{x:0}:(0==a),{x:1}:(1==a)}", + EXPECT_EQUAL("tensor(x[2])(x==a)", Function::parse({"a"}, "tensor(x[2])(x==a)")->dump()); } @@ -1002,11 +1002,8 @@ TEST("require that tensor peek empty label is not allowed") { //----------------------------------------------------------------------------- TEST("require that nested tensor lambda using tensor peek can be parsed") { - vespalib::string expect("tensor(x[2]):{{x:0}:tensor(y[2]):{{y:0}:((0+0)+a),{y:1}:((0+1)+a)}{y:\"0\"}," - 
"{x:1}:tensor(y[2]):{{y:0}:((1+0)+a),{y:1}:((1+1)+a)}{y:\"1\"}}"); + vespalib::string expect("tensor(x[2])(tensor(y[2])((x+y)+a){y:(x)})"); EXPECT_EQUAL(Function::parse(expect)->dump(), expect); - auto fun = Function::parse("tensor(x[2])(tensor(y[2])(x+y+a){y:(x)})"); - EXPECT_EQUAL(fun->dump(), expect); } //----------------------------------------------------------------------------- diff --git a/eval/src/tests/eval/tensor_lambda/CMakeLists.txt b/eval/src/tests/eval/tensor_lambda/CMakeLists.txt new file mode 100644 index 00000000000..29cbbd936aa --- /dev/null +++ b/eval/src/tests/eval/tensor_lambda/CMakeLists.txt @@ -0,0 +1,8 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(eval_tensor_lambda_test_app TEST + SOURCES + tensor_lambda_test.cpp + DEPENDS + vespaeval +) +vespa_add_test(NAME eval_tensor_lambda_test_app COMMAND eval_tensor_lambda_test_app) diff --git a/eval/src/tests/eval/tensor_lambda/tensor_lambda_test.cpp b/eval/src/tests/eval/tensor_lambda/tensor_lambda_test.cpp new file mode 100644 index 00000000000..5b0f2cf0a7e --- /dev/null +++ b/eval/src/tests/eval/tensor_lambda/tensor_lambda_test.cpp @@ -0,0 +1,100 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+ +#include <vespa/vespalib/testkit/test_kit.h> +#include <vespa/eval/eval/tensor_function.h> +#include <vespa/eval/eval/simple_tensor.h> +#include <vespa/eval/eval/simple_tensor_engine.h> +#include <vespa/eval/tensor/default_tensor_engine.h> +#include <vespa/eval/tensor/dense/dense_replace_type_function.h> +#include <vespa/eval/tensor/dense/dense_fast_rename_optimizer.h> +#include <vespa/eval/tensor/dense/dense_tensor.h> +#include <vespa/eval/eval/test/tensor_model.hpp> +#include <vespa/eval/eval/test/eval_fixture.h> +#include <vespa/eval/eval/tensor_nodes.h> + +#include <vespa/vespalib/util/stringfmt.h> +#include <vespa/vespalib/util/stash.h> + +using namespace vespalib; +using namespace vespalib::eval; +using namespace vespalib::eval::test; +using namespace vespalib::tensor; +using namespace vespalib::eval::tensor_function; + +const TensorEngine &prod_engine = DefaultTensorEngine::ref(); + +EvalFixture::ParamRepo make_params() { + return EvalFixture::ParamRepo() + .add("a", spec(1)) + .add("x3", spec({x(3)}, N())) + .add("x3f", spec(float_cells({x(3)}), N())); +} +EvalFixture::ParamRepo param_repo = make_params(); + +void verify_dynamic(const vespalib::string &expr, const vespalib::string &expect) { + EvalFixture fixture(prod_engine, expr, param_repo, true); + EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo)); + EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expect, param_repo)); + auto info = fixture.find_all<Lambda>(); + EXPECT_EQUAL(info.size(), 1u); +} + +void verify_const(const vespalib::string &expr, const vespalib::string &expect) { + EvalFixture fixture(prod_engine, expr, param_repo, true); + EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo)); + EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expect, param_repo)); + auto info = fixture.find_all<ConstValue>(); + EXPECT_EQUAL(info.size(), 1u); +} + +TEST("require that simple constant tensor lambda works") { + TEST_DO(verify_const("tensor(x[3])(x+1)", 
"tensor(x[3]):[1,2,3]")); +} + +TEST("require that simple dynamic tensor lambda works") { + TEST_DO(verify_dynamic("tensor(x[3])(x+a)", "tensor(x[3]):[1,2,3]")); +} + +TEST("require that tensor lambda can be used for tensor slicing") { + TEST_DO(verify_dynamic("tensor(x[2])(x3{x:(x+a)})", "tensor(x[2]):[2,3]")); + TEST_DO(verify_dynamic("tensor(x[2])(a+x3{x:(x)})", "tensor(x[2]):[2,3]")); +} + +TEST("require that tensor lambda can be used for tensor casting") { + TEST_DO(verify_dynamic("tensor(x[3])(x3f{x:(x)})", "tensor(x[3]):[1,2,3]")); + TEST_DO(verify_dynamic("tensor<float>(x[3])(x3{x:(x)})", "tensor<float>(x[3]):[1,2,3]")); +} + +TEST("require that constant nested tensor lambda using tensor peek works") { + TEST_DO(verify_const("tensor(x[2])(tensor(y[2])((x+y)+1){y:(x)})", "tensor(x[2]):[1,3]")); +} + +TEST("require that dynamic nested tensor lambda using tensor peek works") { + TEST_DO(verify_dynamic("tensor(x[2])(tensor(y[2])((x+y)+a){y:(x)})", "tensor(x[2]):[1,3]")); +} + +TEST("require that non-double result from inner tensor lambda function fails type resolving") { + auto fun_a = Function::parse("tensor(x[2])(a)"); + auto fun_b = Function::parse("tensor(x[2])(a{y:(x)})"); + NodeTypes types_ad(*fun_a, {ValueType::from_spec("double")}); + NodeTypes types_at(*fun_a, {ValueType::from_spec("tensor(y[2])")}); + NodeTypes types_bd(*fun_b, {ValueType::from_spec("double")}); + NodeTypes types_bt(*fun_b, {ValueType::from_spec("tensor(y[2])")}); + EXPECT_EQUAL(types_ad.get_type(fun_a->root()).to_spec(), "tensor(x[2])"); + EXPECT_EQUAL(types_at.get_type(fun_a->root()).to_spec(), "error"); + EXPECT_EQUAL(types_bd.get_type(fun_b->root()).to_spec(), "error"); + EXPECT_EQUAL(types_bt.get_type(fun_b->root()).to_spec(), "tensor(x[2])"); +} + +TEST("require that type resolving also include nodes in the inner tensor lambda function") { + auto fun = Function::parse("tensor(x[2])(a)"); + NodeTypes types(*fun, {ValueType::from_spec("double")}); + auto lambda = 
nodes::as<nodes::TensorLambda>(fun->root()); + ASSERT_TRUE(lambda != nullptr); + EXPECT_EQUAL(types.get_type(*lambda).to_spec(), "tensor(x[2])"); + auto symbol = nodes::as<nodes::Symbol>(lambda->lambda().root()); + ASSERT_TRUE(symbol != nullptr); + EXPECT_EQUAL(types.get_type(*symbol).to_spec(), "double"); +} + +TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/eval/src/vespa/eval/eval/basic_nodes.cpp b/eval/src/vespa/eval/eval/basic_nodes.cpp index 1c46f2ee322..9d2b3a619fb 100644 --- a/eval/src/vespa/eval/eval/basic_nodes.cpp +++ b/eval/src/vespa/eval/eval/basic_nodes.cpp @@ -20,12 +20,6 @@ struct Frame { const Node &next_child() { return node.get_child(child_idx++); } }; -struct NoParams : LazyParams { - const Value &resolve(size_t, Stash &) const override { - abort(); - } -}; - } // namespace vespalib::eval::nodes::<unnamed> double diff --git a/eval/src/vespa/eval/eval/function.cpp b/eval/src/vespa/eval/eval/function.cpp index 53107fefb32..caa5e94a681 100644 --- a/eval/src/vespa/eval/eval/function.cpp +++ b/eval/src/vespa/eval/eval/function.cpp @@ -31,17 +31,6 @@ bool has_duplicates(const std::vector<vespalib::string> &list) { return false; } -bool step_labels(std::vector<size_t> &labels, const ValueType &type) { - for (size_t idx = labels.size(); idx-- > 0; ) { - if (++labels[idx] < type.dimensions()[idx].size) { - return true; - } else { - labels[idx] = 0; - } - } - return false; -} - //----------------------------------------------------------------------------- class Params { @@ -89,6 +78,13 @@ struct ExplicitParams : Params { }; struct ImplicitParams : Params { + ImplicitParams() = default; + explicit ImplicitParams(const std::vector<vespalib::string> &params_in) { + for (const auto &param: params_in) { + assert(lookup(param) == UNDEF); + lookup_add(param); + } + } bool implicit() const override { return true; } size_t resolve(const vespalib::string &token) const override { return const_cast<ImplicitParams*>(this)->lookup_add(token); @@ -100,10 +96,8 @@ struct
ImplicitParams : Params { struct ResolveContext { const Params &params; const SymbolExtractor *symbol_extractor; - const std::map<vespalib::string,size_t*> *aliases; - ResolveContext(const Params &params_in, const SymbolExtractor *symbol_extractor_in, - const std::map<vespalib::string,size_t*> *aliases_in) - : params(params_in), symbol_extractor(symbol_extractor_in), aliases(aliases_in) {} + ResolveContext(const Params &params_in, const SymbolExtractor *symbol_extractor_in) + : params(params_in), symbol_extractor(symbol_extractor_in) {} }; class ParseContext @@ -127,7 +121,7 @@ public: _scratch(), _failure(), _expression_stack(), _operator_stack(), _operator_mark(0), - _resolve_stack({ResolveContext(params, symbol_extractor, nullptr)}) + _resolve_stack({ResolveContext(params, symbol_extractor)}) { if (_pos < _end) { _curr = *_pos; @@ -151,12 +145,11 @@ public: } void push_resolve_context(const Params &params) { - _resolve_stack.emplace_back(params, nullptr, nullptr); - } - - void push_resolve_context(const std::map<vespalib::string,size_t*> &aliases) { - assert(!_resolve_stack.empty()); - _resolve_stack.emplace_back(resolver().params, resolver().symbol_extractor, &aliases); + if (params.implicit()) { + _resolve_stack.emplace_back(params, resolver().symbol_extractor); + } else { + _resolve_stack.emplace_back(params, nullptr); + } } void pop_resolve_context() { @@ -165,21 +158,6 @@ public: assert(!_resolve_stack.empty()); } - bool has_alias(const vespalib::string &ident) const { - if (auto aliases = resolver().aliases) { - return (aliases->find(ident) != aliases->end()); - } - return false; - } - - size_t get_alias_value(const vespalib::string &ident) const { - auto aliases = resolver().aliases; - assert(aliases); - auto pos = aliases->find(ident); - assert(pos != aliases->end()); - return *(pos->second); - } - void fail(const vespalib::string &msg) { if (_failure.empty()) { _failure = msg; @@ -761,30 +739,25 @@ void parse_tensor_create(ParseContext &ctx, const ValueType &type, }
void parse_tensor_lambda(ParseContext &ctx, const ValueType &type) { + ImplicitParams params(type.dimension_names()); + ctx.push_resolve_context(params); ctx.skip_spaces(); ctx.eat('('); - ParseContext::InputMark before_expr = ctx.get_input_mark(); - std::vector<size_t> params(type.dimensions().size(), 0); - std::map<vespalib::string,size_t*> my_aliases; - if (auto parent_aliases = ctx.resolver().aliases) { - my_aliases = *parent_aliases; - } - for (size_t i = 0; i < params.size(); ++i) { - my_aliases.emplace(type.dimensions()[i].name, &params[i]); - } - ctx.push_resolve_context(my_aliases); - nodes::TensorCreate::Spec create_spec; - do { - ctx.restore_input_mark(before_expr); - TensorSpec::Address address; - for (size_t i = 0; i < params.size(); ++i) { - address.emplace(type.dimensions()[i].name, params[i]); - } - create_spec.emplace(std::move(address), get_expression(ctx)); - } while (!ctx.failed() && step_labels(params, type)); + Node_UP lambda_root = get_expression(ctx); ctx.eat(')'); + ctx.skip_spaces(); ctx.pop_resolve_context(); - ctx.push_expression(std::make_unique<nodes::TensorCreate>(type, std::move(create_spec))); + auto param_names = params.extract(); + std::vector<size_t> bindings; + for (size_t i = type.dimensions().size(); i < param_names.size(); ++i) { + size_t id = ctx.resolve_parameter(param_names[i]); + if (id == Params::UNDEF) { + return ctx.fail(make_string("unable to resolve: '%s'", param_names[i].c_str())); + } + bindings.push_back(id); + } + auto function = Function::create(std::move(lambda_root), std::move(param_names)); + ctx.push_expression(std::make_unique<nodes::TensorLambda>(type, std::move(bindings), std::move(function))); } bool maybe_parse_tensor_generator(ParseContext &ctx) { @@ -896,18 +869,14 @@ void parse_symbol_or_call(ParseContext &ctx) { bool was_tensor_generate = ((name == "tensor") && maybe_parse_tensor_generator(ctx)); if (!was_tensor_generate && !maybe_parse_call(ctx, name)) { ctx.extract_symbol(name, before_name); - if
(ctx.has_alias(name)) { - ctx.push_expression(Node_UP(new nodes::Number(ctx.get_alias_value(name)))); + if (name.empty()) { + ctx.fail("missing value"); } else { - if (name.empty()) { - ctx.fail("missing value"); + size_t id = ctx.resolve_parameter(name); + if (id == Params::UNDEF) { + ctx.fail(make_string("unknown symbol: '%s'", name.c_str())); } else { - size_t id = ctx.resolve_parameter(name); - if (id == Params::UNDEF) { - ctx.fail(make_string("unknown symbol: '%s'", name.c_str())); - } else { - ctx.push_expression(Node_UP(new nodes::Symbol(id))); - } + ctx.push_expression(Node_UP(new nodes::Symbol(id))); } } } diff --git a/eval/src/vespa/eval/eval/key_gen.cpp b/eval/src/vespa/eval/eval/key_gen.cpp index 46137b5878c..cd31f92f96a 100644 --- a/eval/src/vespa/eval/eval/key_gen.cpp +++ b/eval/src/vespa/eval/eval/key_gen.cpp @@ -42,7 +42,8 @@ struct KeyGen : public NodeVisitor, public NodeTraverser { void visit(const TensorRename &) override { add_byte(14); } // dimensions should be part of key void visit(const TensorConcat &) override { add_byte(15); } // dimension should be part of key void visit(const TensorCreate &) override { add_byte(16); } // type/addr should be part of key - void visit(const TensorPeek &) override { add_byte(17); } // addr should be part of key + void visit(const TensorLambda &) override { add_byte(17); } // type/lambda should be part of key + void visit(const TensorPeek &) override { add_byte(18); } // addr should be part of key void visit(const Add &) override { add_byte(20); } void visit(const Sub &) override { add_byte(21); } void visit(const Mul &) override { add_byte(22); } diff --git a/eval/src/vespa/eval/eval/lazy_params.h b/eval/src/vespa/eval/eval/lazy_params.h index b6b7753eff9..d75e216571b 100644 --- a/eval/src/vespa/eval/eval/lazy_params.h +++ b/eval/src/vespa/eval/eval/lazy_params.h @@ -46,5 +46,14 @@ struct SimpleParams : LazyParams { const Value &resolve(size_t idx, Stash &stash) const override; }; +/** + * Simple wrapper for 
cases where you have no parameters. + **/ +struct NoParams : LazyParams { + const Value &resolve(size_t, Stash &) const override { + abort(); + } +}; + } // namespace vespalib::eval } // namespace vespalib diff --git a/eval/src/vespa/eval/eval/llvm/compiled_function.cpp b/eval/src/vespa/eval/eval/llvm/compiled_function.cpp index facfc502111..4101cf10e1f 100644 --- a/eval/src/vespa/eval/eval/llvm/compiled_function.cpp +++ b/eval/src/vespa/eval/eval/llvm/compiled_function.cpp @@ -133,6 +133,7 @@ CompiledFunction::detect_issues(const Function &function) nodes::TensorRename, nodes::TensorConcat, nodes::TensorCreate, + nodes::TensorLambda, nodes::TensorPeek>(node)) { issues.push_back(make_string("unsupported node type: %s", diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp index cce9838d967..89f1789e97b 100644 --- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp +++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp @@ -485,6 +485,9 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser { void visit(const TensorCreate &node) override { make_error(node.num_children()); } + void visit(const TensorLambda &node) override { + make_error(node.num_children()); + } void visit(const TensorPeek &node) override { make_error(node.num_children()); } diff --git a/eval/src/vespa/eval/eval/make_tensor_function.cpp b/eval/src/vespa/eval/eval/make_tensor_function.cpp index 849270b89a7..bbf6cadbac2 100644 --- a/eval/src/vespa/eval/eval/make_tensor_function.cpp +++ b/eval/src/vespa/eval/eval/make_tensor_function.cpp @@ -121,6 +121,17 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser { stack.push_back(tensor_function::create(node.type(), spec, stash)); } + void make_lambda(const TensorLambda &node) { + InterpretedFunction my_fun(tensor_engine, node.lambda().root(), node.type().dimensions().size(), types); + if (node.bindings().empty()) { + NoParams no_params; + TensorSpec spec = 
tensor_function::Lambda::create_spec_impl(node.type(), no_params, node.bindings(), my_fun); + make_const(node, *stash.create<Value::UP>(tensor_engine.from_spec(spec))); + } else { + stack.push_back(tensor_function::lambda(node.type(), node.bindings(), std::move(my_fun), stash)); + } + } + void make_peek(const TensorPeek &node) { assert(stack.size() >= node.num_children()); const tensor_function::Node &param = stack[stack.size()-node.num_children()]; @@ -221,6 +232,9 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser { void visit(const TensorCreate &node) override { make_create(node); } + void visit(const TensorLambda &node) override { + make_lambda(node); + } void visit(const TensorPeek &node) override { make_peek(node); } diff --git a/eval/src/vespa/eval/eval/node_types.cpp b/eval/src/vespa/eval/eval/node_types.cpp index bf87628e301..37632e35693 100644 --- a/eval/src/vespa/eval/eval/node_types.cpp +++ b/eval/src/vespa/eval/eval/node_types.cpp @@ -53,6 +53,13 @@ struct TypeResolver : public NodeVisitor, public NodeTraverser { return state.type(node); } + void import(const NodeTypes &types) { + types.each([&](const Node &node, const ValueType &type) + { + state.bind(type, node); + }); + } + //------------------------------------------------------------------------- bool check_error(const Node &node) { @@ -121,6 +128,22 @@ struct TypeResolver : public NodeVisitor, public NodeTraverser { } bind(node.type(), node); } + void visit(const TensorLambda &node) override { + std::vector<ValueType> arg_types; + for (const auto &dim: node.type().dimensions()) { + (void) dim; + arg_types.push_back(ValueType::double_type()); + } + for (size_t binding: node.bindings()) { + arg_types.push_back(param_type(binding)); + } + NodeTypes lambda_types(node.lambda(), arg_types); + if (!lambda_types.get_type(node.lambda().root()).is_double()) { + return bind(ValueType::error_type(), node); + } + import(lambda_types); + bind(node.type(), node); + } void visit(const
TensorPeek &node) override { const ValueType &param_type = type(node.param()); std::vector<vespalib::string> dimensions; diff --git a/eval/src/vespa/eval/eval/node_types.h b/eval/src/vespa/eval/eval/node_types.h index a9ddf371c31..a93f886a2eb 100644 --- a/eval/src/vespa/eval/eval/node_types.h +++ b/eval/src/vespa/eval/eval/node_types.h @@ -27,18 +27,19 @@ public: NodeTypes(); NodeTypes(const Function &function, const std::vector<ValueType> &input_types); const ValueType &get_type(const nodes::Node &node) const; - template <typename P> - bool check_types(const P &pred) const { + template <typename F> + void each(F &&f) const { for (const auto &entry: _type_map) { - if (!pred(entry.second)) { - return false; - } + f(*entry.first, entry.second); } - return (_type_map.size() > 0); } bool all_types_are_double() const { - return check_types([](const ValueType &type) - { return type.is_double(); }); + bool all_double = true; + each([&all_double](const nodes::Node &, const ValueType &type) + { + all_double &= type.is_double(); + }); + return (all_double && (_type_map.size() > 0)); } }; diff --git a/eval/src/vespa/eval/eval/node_visitor.h b/eval/src/vespa/eval/eval/node_visitor.h index 8f9722858b7..d3e066c8f53 100644 --- a/eval/src/vespa/eval/eval/node_visitor.h +++ b/eval/src/vespa/eval/eval/node_visitor.h @@ -36,6 +36,7 @@ struct NodeVisitor { virtual void visit(const nodes::TensorRename &) = 0; virtual void visit(const nodes::TensorConcat &) = 0; virtual void visit(const nodes::TensorCreate &) = 0; + virtual void visit(const nodes::TensorLambda &) = 0; virtual void visit(const nodes::TensorPeek &) = 0; // operator nodes @@ -106,6 +107,7 @@ struct EmptyNodeVisitor : NodeVisitor { void visit(const nodes::TensorRename &) override {} void visit(const nodes::TensorConcat &) override {} void visit(const nodes::TensorCreate &) override {} + void visit(const nodes::TensorLambda &) override {} void visit(const nodes::TensorPeek &) override {} void visit(const nodes::Add &) override
{} void visit(const nodes::Sub &) override {} diff --git a/eval/src/vespa/eval/eval/tensor_function.cpp b/eval/src/vespa/eval/eval/tensor_function.cpp index 45e8094570e..889738a201d 100644 --- a/eval/src/vespa/eval/eval/tensor_function.cpp +++ b/eval/src/vespa/eval/eval/tensor_function.cpp @@ -9,6 +9,7 @@ #include "visit_stuff.h" #include "string_stuff.h" #include <vespa/vespalib/objects/objectdumper.h> +#include <vespa/vespalib/objects/visit.hpp> #include <vespa/log/log.h> LOG_SETUP(".eval.eval.tensor_function"); @@ -133,6 +134,13 @@ void op_tensor_create(State &state, uint64_t param) { state.pop_n_push(i, result); } +void op_tensor_lambda(State &state, uint64_t param) { + const Lambda &self = unwrap_param<Lambda>(param); + TensorSpec spec = self.create_spec(*state.params); + const Value &result = *state.stash.create<Value::UP>(state.engine.from_spec(spec)); + state.stack.emplace_back(result); +} + const Value &extract_single_value(const TensorSpec &spec, const TensorSpec::Address &addr, State &state) { auto pos = spec.cells().find(addr); if (pos == spec.cells().end()) { @@ -193,7 +201,7 @@ void op_tensor_peek(State &state, uint64_t param) { state.pop_n_push(child_cnt, result); } -} // namespace vespalib::eval::tensor_function +} // namespace vespalib::eval::tensor_function::<unnamed> //----------------------------------------------------------------------------- @@ -381,6 +389,72 @@ Create::visit_children(vespalib::ObjectVisitor &visitor) const //----------------------------------------------------------------------------- +namespace { + +bool step_labels(std::vector<size_t> &labels, const ValueType &type) { + for (size_t idx = labels.size(); idx-- > 0; ) { + if (++labels[idx] < type.dimensions()[idx].size) { + return true; + } else { + labels[idx] = 0; + } + } + return false; +} + +struct ParamProxy : public LazyParams { + const std::vector<size_t> &labels; + const LazyParams &params; + const std::vector<size_t> &bindings; + ParamProxy(const std::vector<size_t>
&labels_in, const LazyParams &params_in, const std::vector<size_t> &bindings_in) + : labels(labels_in), params(params_in), bindings(bindings_in) {} + const Value &resolve(size_t idx, Stash &stash) const override { + if (idx < labels.size()) { + return stash.create<DoubleValue>(labels[idx]); + } + return params.resolve(bindings[idx - labels.size()], stash); + } +}; + +} + +TensorSpec +Lambda::create_spec_impl(const ValueType &type, const LazyParams &params, const std::vector<size_t> &bind, const InterpretedFunction &fun) +{ + std::vector<size_t> labels(type.dimensions().size(), 0); + ParamProxy param_proxy(labels, params, bind); + InterpretedFunction::Context ctx(fun); + TensorSpec spec(type.to_spec()); + do { + TensorSpec::Address address; + for (size_t i = 0; i < labels.size(); ++i) { + address.emplace(type.dimensions()[i].name, labels[i]); + } + spec.add(std::move(address), fun.eval(ctx, param_proxy).as_double()); + } while (step_labels(labels, type)); + return spec; +} + +InterpretedFunction::Instruction +Lambda::compile_self(Stash &) const +{ + return Instruction(op_tensor_lambda, wrap_param<Lambda>(*this)); +} + +void +Lambda::push_children(std::vector<Child::CREF> &) const +{ +} + +void +Lambda::visit_self(vespalib::ObjectVisitor &visitor) const +{ + Super::visit_self(visitor); + ::visit(visitor, "bindings", _bindings); +} + +//----------------------------------------------------------------------------- + void Peek::push_children(std::vector<Child::CREF> &children) const { @@ -504,6 +578,10 @@ const Node &create(const ValueType &type, const std::map<TensorSpec::Address,Nod return stash.create<Create>(type, spec); } +const Node &lambda(const ValueType &type, const std::vector<size_t> &bindings, InterpretedFunction function, Stash &stash) { + return stash.create<Lambda>(type, bindings, std::move(function)); +} + const Node &peek(const Node &param, const std::map<vespalib::string, std::variant<TensorSpec::Label, Node::CREF>> &spec, Stash &stash) {
std::vector<vespalib::string> dimensions; for (const auto &dim_spec: spec) { diff --git a/eval/src/vespa/eval/eval/tensor_function.h b/eval/src/vespa/eval/eval/tensor_function.h index 55d27fb74ea..c4e7384abcd 100644 --- a/eval/src/vespa/eval/eval/tensor_function.h +++ b/eval/src/vespa/eval/eval/tensor_function.h @@ -30,11 +30,13 @@ class Tensor; * with information about operation sequencing and intermediate * results. Each node in the tree describes a single tensor * operation. This is the intermediate representation of a tensor - * function. + * function. Note that some nodes in the tree are already indirectly + * implementation-specific in that they are bound to a specific tensor + * engine (typically tensor constants and tensor lambdas). * * A tensor function will initially be created based on a Function - * (expression AST) and associated type-resolving. In this tree, each - * node will directly represent a single call to the tensor engine + * (expression AST) and associated type-resolving. In this tree, most + * nodes will directly represent a single call to the tensor engine * immediate API. 
* * The generic tree will then be optimized (in-place, bottom-up) where @@ -323,6 +325,25 @@ public: //----------------------------------------------------------------------------- +class Lambda : public Node +{ + using Super = Node; +private: + std::vector<size_t> _bindings; + InterpretedFunction _lambda; +public: + Lambda(const ValueType &result_type_in, const std::vector<size_t> &bindings_in, InterpretedFunction lambda_in) + : Node(result_type_in), _bindings(bindings_in), _lambda(std::move(lambda_in)) {} + static TensorSpec create_spec_impl(const ValueType &type, const LazyParams &params, const std::vector<size_t> &bind, const InterpretedFunction &fun); + TensorSpec create_spec(const LazyParams &params) const { return create_spec_impl(result_type(), params, _bindings, _lambda); } + bool result_is_mutable() const override { return true; } + InterpretedFunction::Instruction compile_self(Stash &stash) const final override; + void push_children(std::vector<Child::CREF> &children) const final override; + void visit_self(vespalib::ObjectVisitor &visitor) const override; +}; + +//----------------------------------------------------------------------------- + class Peek : public Node { using Super = Node; @@ -413,6 +434,7 @@ const Node &join(const Node &lhs, const Node &rhs, join_fun_t function, Stash &s const Node &merge(const Node &lhs, const Node &rhs, join_fun_t function, Stash &stash); const Node &concat(const Node &lhs, const Node &rhs, const vespalib::string &dimension, Stash &stash); const Node &create(const ValueType &type, const std::map<TensorSpec::Address, Node::CREF> &spec, Stash &stash); +const Node &lambda(const ValueType &type, const std::vector<size_t> &bindings, InterpretedFunction function, Stash &stash); const Node &peek(const Node &param, const std::map<vespalib::string, std::variant<TensorSpec::Label, Node::CREF>> &spec, Stash &stash); const Node &rename(const Node &child, const std::vector<vespalib::string> &from, const std::vector<vespalib::string> &to,
Stash &stash); const Node &if_node(const Node &cond, const Node &true_child, const Node &false_child, Stash &stash); diff --git a/eval/src/vespa/eval/eval/tensor_nodes.cpp b/eval/src/vespa/eval/eval/tensor_nodes.cpp index 82d108300dd..5cb064ad127 100644 --- a/eval/src/vespa/eval/eval/tensor_nodes.cpp +++ b/eval/src/vespa/eval/eval/tensor_nodes.cpp @@ -14,6 +14,7 @@ void TensorReduce::accept(NodeVisitor &visitor) const { visitor.visit(*this); } void TensorRename::accept(NodeVisitor &visitor) const { visitor.visit(*this); } void TensorConcat::accept(NodeVisitor &visitor) const { visitor.visit(*this); } void TensorCreate::accept(NodeVisitor &visitor) const { visitor.visit(*this); } +void TensorLambda::accept(NodeVisitor &visitor) const { visitor.visit(*this); } void TensorPeek ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } } // namespace vespalib::eval::nodes diff --git a/eval/src/vespa/eval/eval/tensor_nodes.h b/eval/src/vespa/eval/eval/tensor_nodes.h index daba46c1fc5..07be7e77b71 100644 --- a/eval/src/vespa/eval/eval/tensor_nodes.h +++ b/eval/src/vespa/eval/eval/tensor_nodes.h @@ -267,6 +267,39 @@ public: } }; +class TensorLambda : public Node { +private: + ValueType _type; + std::vector<size_t> _bindings; + std::shared_ptr<Function const> _lambda; +public: + TensorLambda(ValueType type_in, std::vector<size_t> bindings, std::shared_ptr<Function const> lambda) + : _type(std::move(type_in)), _bindings(std::move(bindings)), _lambda(std::move(lambda)) + { + assert(_type.is_dense()); + assert(_lambda->num_params() == (_type.dimensions().size() + _bindings.size())); + } + const ValueType &type() const { return _type; } + const std::vector<size_t> &bindings() const { return _bindings; } + const Function &lambda() const { return *_lambda; } + vespalib::string dump(DumpContext &) const override { + vespalib::string str = _type.to_spec(); + vespalib::string expr = _lambda->dump(); + if (starts_with(expr, "(")) { + str += expr; + } else { + str += "("; + str 
+= expr; + str += ")"; + } + return str; + } + void accept(NodeVisitor &visitor) const override; + size_t num_children() const override { return 0; } + const Node &get_child(size_t) const override { abort(); } + void detach_children(NodeHandler &) override {} +}; + class TensorPeek : public Node { public: struct MyLabel { diff --git a/eval/src/vespa/eval/eval/tensor_spec.h b/eval/src/vespa/eval/eval/tensor_spec.h index 25af4c7a93c..22aa47f5ddb 100644 --- a/eval/src/vespa/eval/eval/tensor_spec.h +++ b/eval/src/vespa/eval/eval/tensor_spec.h @@ -66,8 +66,8 @@ public: TensorSpec(const TensorSpec &); TensorSpec & operator = (const TensorSpec &); ~TensorSpec(); - TensorSpec &add(const Address &address, double value) { - auto res = _cells.emplace(address, value); + TensorSpec &add(Address address, double value) { + auto res = _cells.emplace(std::move(address), value); if (!res.second) { res.first->second.value += value; } diff --git a/eval/src/vespa/eval/eval/test/eval_fixture.cpp b/eval/src/vespa/eval/eval/test/eval_fixture.cpp index 1e17c8284cb..57db05107bc 100644 --- a/eval/src/vespa/eval/eval/test/eval_fixture.cpp +++ b/eval/src/vespa/eval/eval/test/eval_fixture.cpp @@ -10,6 +10,14 @@ using ParamRepo = EvalFixture::ParamRepo; namespace { +std::shared_ptr<Function const> verify_function(std::shared_ptr<Function const> fun) { + if (fun->has_error()) { + fprintf(stderr, "eval_fixture: function parse failed: %s\n", fun->get_error().c_str()); + } + ASSERT_TRUE(!fun->has_error()); + return std::move(fun); +} + NodeTypes get_types(const Function &function, const ParamRepo &param_repo) { std::vector<ValueType> param_types; for (size_t i = 0; i < function.num_params(); ++i) { @@ -103,7 +111,7 @@ EvalFixture::EvalFixture(const TensorEngine &engine, bool allow_mutable) : _engine(engine), _stash(), - _function(Function::parse(expr)), + _function(verify_function(Function::parse(expr))), _node_types(get_types(*_function, param_repo)), _mutable_set(get_mutable(*_function, param_repo)),
_plain_tensor_function(make_tensor_function(_engine, _function->root(), _node_types, _stash)), |