summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2020-03-09 16:10:22 +0000
committerHåvard Pettersen <havardpe@oath.com>2020-03-11 14:00:42 +0000
commitdcd0747caf70c53abca8dd1d6c7def42428c1f5c (patch)
tree0d8d59bf2cd50893e795abdb76ea94fbd14402ed /eval
parentb299b794b52afeacebcd875775f76e9ebf4f0dd7 (diff)
handle tensor lambda as nested function with bindings
Diffstat (limited to 'eval')
-rw-r--r--eval/CMakeLists.txt1
-rw-r--r--eval/src/tests/eval/function/function_test.cpp11
-rw-r--r--eval/src/tests/eval/tensor_lambda/CMakeLists.txt8
-rw-r--r--eval/src/tests/eval/tensor_lambda/tensor_lambda_test.cpp100
-rw-r--r--eval/src/vespa/eval/eval/basic_nodes.cpp6
-rw-r--r--eval/src/vespa/eval/eval/function.cpp103
-rw-r--r--eval/src/vespa/eval/eval/key_gen.cpp3
-rw-r--r--eval/src/vespa/eval/eval/lazy_params.h9
-rw-r--r--eval/src/vespa/eval/eval/llvm/compiled_function.cpp1
-rw-r--r--eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp3
-rw-r--r--eval/src/vespa/eval/eval/make_tensor_function.cpp14
-rw-r--r--eval/src/vespa/eval/eval/node_types.cpp23
-rw-r--r--eval/src/vespa/eval/eval/node_types.h17
-rw-r--r--eval/src/vespa/eval/eval/node_visitor.h2
-rw-r--r--eval/src/vespa/eval/eval/tensor_function.cpp80
-rw-r--r--eval/src/vespa/eval/eval/tensor_function.h28
-rw-r--r--eval/src/vespa/eval/eval/tensor_nodes.cpp1
-rw-r--r--eval/src/vespa/eval/eval/tensor_nodes.h33
-rw-r--r--eval/src/vespa/eval/eval/tensor_spec.h4
-rw-r--r--eval/src/vespa/eval/eval/test/eval_fixture.cpp10
20 files changed, 361 insertions, 96 deletions
diff --git a/eval/CMakeLists.txt b/eval/CMakeLists.txt
index febb254c53e..8eb198a9a0b 100644
--- a/eval/CMakeLists.txt
+++ b/eval/CMakeLists.txt
@@ -22,6 +22,7 @@ vespa_define_module(
src/tests/eval/param_usage
src/tests/eval/simple_tensor
src/tests/eval/tensor_function
+ src/tests/eval/tensor_lambda
src/tests/eval/tensor_spec
src/tests/eval/value_cache
src/tests/eval/value_type
diff --git a/eval/src/tests/eval/function/function_test.cpp b/eval/src/tests/eval/function/function_test.cpp
index 5316217d549..11fd8abf549 100644
--- a/eval/src/tests/eval/function/function_test.cpp
+++ b/eval/src/tests/eval/function/function_test.cpp
@@ -813,8 +813,8 @@ TEST("require that tensor rename dimension lists must have equal size") {
//-----------------------------------------------------------------------------
TEST("require that tensor lambda can be parsed") {
- EXPECT_EQUAL("tensor(x[3]):{{x:0}:0,{x:1}:1,{x:2}:2}", Function::parse({}, "tensor(x[3])(x)")->dump());
- EXPECT_EQUAL("tensor(x[2],y[2]):{{x:0,y:0}:(0==0),{x:0,y:1}:(0==1),{x:1,y:0}:(1==0),{x:1,y:1}:(1==1)}",
+ EXPECT_EQUAL("tensor(x[3])(x)", Function::parse({}, "tensor(x[3])(x)")->dump());
+ EXPECT_EQUAL("tensor(x[2],y[2])(x==y)",
Function::parse({}, " tensor ( x [ 2 ] , y [ 2 ] ) ( x == y ) ")->dump());
}
@@ -825,7 +825,7 @@ TEST("require that tensor lambda requires appropriate tensor type") {
}
TEST("require that tensor lambda can use non-dimension symbols") {
- EXPECT_EQUAL("tensor(x[2]):{{x:0}:(0==a),{x:1}:(1==a)}",
+ EXPECT_EQUAL("tensor(x[2])(x==a)",
Function::parse({"a"}, "tensor(x[2])(x==a)")->dump());
}
@@ -1002,11 +1002,8 @@ TEST("require that tensor peek empty label is not allowed") {
//-----------------------------------------------------------------------------
TEST("require that nested tensor lambda using tensor peek can be parsed") {
- vespalib::string expect("tensor(x[2]):{{x:0}:tensor(y[2]):{{y:0}:((0+0)+a),{y:1}:((0+1)+a)}{y:\"0\"},"
- "{x:1}:tensor(y[2]):{{y:0}:((1+0)+a),{y:1}:((1+1)+a)}{y:\"1\"}}");
+ vespalib::string expect("tensor(x[2])(tensor(y[2])((x+y)+a){y:(x)})");
EXPECT_EQUAL(Function::parse(expect)->dump(), expect);
- auto fun = Function::parse("tensor(x[2])(tensor(y[2])(x+y+a){y:(x)})");
- EXPECT_EQUAL(fun->dump(), expect);
}
//-----------------------------------------------------------------------------
diff --git a/eval/src/tests/eval/tensor_lambda/CMakeLists.txt b/eval/src/tests/eval/tensor_lambda/CMakeLists.txt
new file mode 100644
index 00000000000..29cbbd936aa
--- /dev/null
+++ b/eval/src/tests/eval/tensor_lambda/CMakeLists.txt
@@ -0,0 +1,8 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+vespa_add_executable(eval_tensor_lambda_test_app TEST
+ SOURCES
+ tensor_lambda_test.cpp
+ DEPENDS
+ vespaeval
+)
+vespa_add_test(NAME eval_tensor_lambda_test_app COMMAND eval_tensor_lambda_test_app)
diff --git a/eval/src/tests/eval/tensor_lambda/tensor_lambda_test.cpp b/eval/src/tests/eval/tensor_lambda/tensor_lambda_test.cpp
new file mode 100644
index 00000000000..5b0f2cf0a7e
--- /dev/null
+++ b/eval/src/tests/eval/tensor_lambda/tensor_lambda_test.cpp
@@ -0,0 +1,100 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/vespalib/testkit/test_kit.h>
+#include <vespa/eval/eval/tensor_function.h>
+#include <vespa/eval/eval/simple_tensor.h>
+#include <vespa/eval/eval/simple_tensor_engine.h>
+#include <vespa/eval/tensor/default_tensor_engine.h>
+#include <vespa/eval/tensor/dense/dense_replace_type_function.h>
+#include <vespa/eval/tensor/dense/dense_fast_rename_optimizer.h>
+#include <vespa/eval/tensor/dense/dense_tensor.h>
+#include <vespa/eval/eval/test/tensor_model.hpp>
+#include <vespa/eval/eval/test/eval_fixture.h>
+#include <vespa/eval/eval/tensor_nodes.h>
+
+#include <vespa/vespalib/util/stringfmt.h>
+#include <vespa/vespalib/util/stash.h>
+
+using namespace vespalib;
+using namespace vespalib::eval;
+using namespace vespalib::eval::test;
+using namespace vespalib::tensor;
+using namespace vespalib::eval::tensor_function;
+
+const TensorEngine &prod_engine = DefaultTensorEngine::ref();
+
+EvalFixture::ParamRepo make_params() {
+ return EvalFixture::ParamRepo()
+ .add("a", spec(1))
+ .add("x3", spec({x(3)}, N()))
+ .add("x3f", spec(float_cells({x(3)}), N()));
+}
+EvalFixture::ParamRepo param_repo = make_params();
+
+void verify_dynamic(const vespalib::string &expr, const vespalib::string &expect) {
+ EvalFixture fixture(prod_engine, expr, param_repo, true);
+ EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo));
+ EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expect, param_repo));
+ auto info = fixture.find_all<Lambda>();
+ EXPECT_EQUAL(info.size(), 1u);
+}
+
+void verify_const(const vespalib::string &expr, const vespalib::string &expect) {
+ EvalFixture fixture(prod_engine, expr, param_repo, true);
+ EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo));
+ EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expect, param_repo));
+ auto info = fixture.find_all<ConstValue>();
+ EXPECT_EQUAL(info.size(), 1u);
+}
+
+TEST("require that simple constant tensor lambda works") {
+ TEST_DO(verify_const("tensor(x[3])(x+1)", "tensor(x[3]):[1,2,3]"));
+}
+
+TEST("require that simple dynamic tensor lambda works") {
+ TEST_DO(verify_dynamic("tensor(x[3])(x+a)", "tensor(x[3]):[1,2,3]"));
+}
+
+TEST("require that tensor lambda can be used for tensor slicing") {
+ TEST_DO(verify_dynamic("tensor(x[2])(x3{x:(x+a)})", "tensor(x[2]):[2,3]"));
+ TEST_DO(verify_dynamic("tensor(x[2])(a+x3{x:(x)})", "tensor(x[2]):[2,3]"));
+}
+
+TEST("require that tensor lambda can be used for tensor casting") {
+ TEST_DO(verify_dynamic("tensor(x[3])(x3f{x:(x)})", "tensor(x[3]):[1,2,3]"));
+ TEST_DO(verify_dynamic("tensor<float>(x[3])(x3{x:(x)})", "tensor<float>(x[3]):[1,2,3]"));
+}
+
+TEST("require that constant nested tensor lambda using tensor peek works") {
+ TEST_DO(verify_const("tensor(x[2])(tensor(y[2])((x+y)+1){y:(x)})", "tensor(x[2]):[1,3]"));
+}
+
+TEST("require that dynamic nested tensor lambda using tensor peek works") {
+ TEST_DO(verify_dynamic("tensor(x[2])(tensor(y[2])((x+y)+a){y:(x)})", "tensor(x[2]):[1,3]"));
+}
+
+TEST("require that non-double result from inner tensor lambda function fails type resolving") {
+ auto fun_a = Function::parse("tensor(x[2])(a)");
+ auto fun_b = Function::parse("tensor(x[2])(a{y:(x)})");
+ NodeTypes types_ad(*fun_a, {ValueType::from_spec("double")});
+ NodeTypes types_at(*fun_a, {ValueType::from_spec("tensor(y[2])")});
+ NodeTypes types_bd(*fun_b, {ValueType::from_spec("double")});
+ NodeTypes types_bt(*fun_b, {ValueType::from_spec("tensor(y[2])")});
+ EXPECT_EQUAL(types_ad.get_type(fun_a->root()).to_spec(), "tensor(x[2])");
+ EXPECT_EQUAL(types_at.get_type(fun_a->root()).to_spec(), "error");
+ EXPECT_EQUAL(types_bd.get_type(fun_b->root()).to_spec(), "error");
+ EXPECT_EQUAL(types_bt.get_type(fun_b->root()).to_spec(), "tensor(x[2])");
+}
+
+TEST("require that type resolving also include nodes in the inner tensor lambda function") {
+ auto fun = Function::parse("tensor(x[2])(a)");
+ NodeTypes types(*fun, {ValueType::from_spec("double")});
+ auto lambda = nodes::as<nodes::TensorLambda>(fun->root());
+ ASSERT_TRUE(lambda != nullptr);
+ EXPECT_EQUAL(types.get_type(*lambda).to_spec(), "tensor(x[2])");
+ auto symbol = nodes::as<nodes::Symbol>(lambda->lambda().root());
+ ASSERT_TRUE(symbol != nullptr);
+ EXPECT_EQUAL(types.get_type(*symbol).to_spec(), "double");
+}
+
+TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/vespa/eval/eval/basic_nodes.cpp b/eval/src/vespa/eval/eval/basic_nodes.cpp
index 1c46f2ee322..9d2b3a619fb 100644
--- a/eval/src/vespa/eval/eval/basic_nodes.cpp
+++ b/eval/src/vespa/eval/eval/basic_nodes.cpp
@@ -20,12 +20,6 @@ struct Frame {
const Node &next_child() { return node.get_child(child_idx++); }
};
-struct NoParams : LazyParams {
- const Value &resolve(size_t, Stash &) const override {
- abort();
- }
-};
-
} // namespace vespalib::eval::nodes::<unnamed>
double
diff --git a/eval/src/vespa/eval/eval/function.cpp b/eval/src/vespa/eval/eval/function.cpp
index 53107fefb32..caa5e94a681 100644
--- a/eval/src/vespa/eval/eval/function.cpp
+++ b/eval/src/vespa/eval/eval/function.cpp
@@ -31,17 +31,6 @@ bool has_duplicates(const std::vector<vespalib::string> &list) {
return false;
}
-bool step_labels(std::vector<size_t> &labels, const ValueType &type) {
- for (size_t idx = labels.size(); idx-- > 0; ) {
- if (++labels[idx] < type.dimensions()[idx].size) {
- return true;
- } else {
- labels[idx] = 0;
- }
- }
- return false;
-}
-
//-----------------------------------------------------------------------------
class Params {
@@ -89,6 +78,13 @@ struct ExplicitParams : Params {
};
struct ImplicitParams : Params {
+ ImplicitParams() = default;
+ explicit ImplicitParams(const std::vector<vespalib::string> &params_in) {
+ for (const auto &param: params_in) {
+ assert(lookup(param) == UNDEF);
+ lookup_add(param);
+ }
+ }
bool implicit() const override { return true; }
size_t resolve(const vespalib::string &token) const override {
return const_cast<ImplicitParams*>(this)->lookup_add(token);
@@ -100,10 +96,8 @@ struct ImplicitParams : Params {
struct ResolveContext {
const Params &params;
const SymbolExtractor *symbol_extractor;
- const std::map<vespalib::string,size_t*> *aliases;
- ResolveContext(const Params &params_in, const SymbolExtractor *symbol_extractor_in,
- const std::map<vespalib::string,size_t*> *aliases_in)
- : params(params_in), symbol_extractor(symbol_extractor_in), aliases(aliases_in) {}
+ ResolveContext(const Params &params_in, const SymbolExtractor *symbol_extractor_in)
+ : params(params_in), symbol_extractor(symbol_extractor_in) {}
};
class ParseContext
@@ -127,7 +121,7 @@ public:
_scratch(), _failure(),
_expression_stack(), _operator_stack(),
_operator_mark(0),
- _resolve_stack({ResolveContext(params, symbol_extractor, nullptr)})
+ _resolve_stack({ResolveContext(params, symbol_extractor)})
{
if (_pos < _end) {
_curr = *_pos;
@@ -151,12 +145,11 @@ public:
}
void push_resolve_context(const Params &params) {
- _resolve_stack.emplace_back(params, nullptr, nullptr);
- }
-
- void push_resolve_context(const std::map<vespalib::string,size_t*> &aliases) {
- assert(!_resolve_stack.empty());
- _resolve_stack.emplace_back(resolver().params, resolver().symbol_extractor, &aliases);
+ if (params.implicit()) {
+ _resolve_stack.emplace_back(params, resolver().symbol_extractor);
+ } else {
+ _resolve_stack.emplace_back(params, nullptr);
+ }
}
void pop_resolve_context() {
@@ -165,21 +158,6 @@ public:
assert(!_resolve_stack.empty());
}
- bool has_alias(const vespalib::string &ident) const {
- if (auto aliases = resolver().aliases) {
- return (aliases->find(ident) != aliases->end());
- }
- return false;
- }
-
- size_t get_alias_value(const vespalib::string &ident) const {
- auto aliases = resolver().aliases;
- assert(aliases);
- auto pos = aliases->find(ident);
- assert(pos != aliases->end());
- return *(pos->second);
- }
-
void fail(const vespalib::string &msg) {
if (_failure.empty()) {
_failure = msg;
@@ -761,30 +739,25 @@ void parse_tensor_create(ParseContext &ctx, const ValueType &type,
}
void parse_tensor_lambda(ParseContext &ctx, const ValueType &type) {
+ ImplicitParams params(type.dimension_names());
+ ctx.push_resolve_context(params);
ctx.skip_spaces();
ctx.eat('(');
- ParseContext::InputMark before_expr = ctx.get_input_mark();
- std::vector<size_t> params(type.dimensions().size(), 0);
- std::map<vespalib::string,size_t*> my_aliases;
- if (auto parent_aliases = ctx.resolver().aliases) {
- my_aliases = *parent_aliases;
- }
- for (size_t i = 0; i < params.size(); ++i) {
- my_aliases.emplace(type.dimensions()[i].name, &params[i]);
- }
- ctx.push_resolve_context(my_aliases);
- nodes::TensorCreate::Spec create_spec;
- do {
- ctx.restore_input_mark(before_expr);
- TensorSpec::Address address;
- for (size_t i = 0; i < params.size(); ++i) {
- address.emplace(type.dimensions()[i].name, params[i]);
- }
- create_spec.emplace(std::move(address), get_expression(ctx));
- } while (!ctx.failed() && step_labels(params, type));
+ Node_UP lambda_root = get_expression(ctx);
ctx.eat(')');
+ ctx.skip_spaces();
ctx.pop_resolve_context();
- ctx.push_expression(std::make_unique<nodes::TensorCreate>(type, std::move(create_spec)));
+ auto param_names = params.extract();
+ std::vector<size_t> bindings;
+ for (size_t i = type.dimensions().size(); i < param_names.size(); ++i) {
+ size_t id = ctx.resolve_parameter(param_names[i]);
+ if (id == Params::UNDEF) {
+ return ctx.fail(make_string("unable to resolve: '%s'", param_names[i].c_str()));
+ }
+ bindings.push_back(id);
+ }
+ auto function = Function::create(std::move(lambda_root), std::move(param_names));
+ ctx.push_expression(std::make_unique<nodes::TensorLambda>(type, std::move(bindings), std::move(function)));
}
bool maybe_parse_tensor_generator(ParseContext &ctx) {
@@ -896,18 +869,14 @@ void parse_symbol_or_call(ParseContext &ctx) {
bool was_tensor_generate = ((name == "tensor") && maybe_parse_tensor_generator(ctx));
if (!was_tensor_generate && !maybe_parse_call(ctx, name)) {
ctx.extract_symbol(name, before_name);
- if (ctx.has_alias(name)) {
- ctx.push_expression(Node_UP(new nodes::Number(ctx.get_alias_value(name))));
+ if (name.empty()) {
+ ctx.fail("missing value");
} else {
- if (name.empty()) {
- ctx.fail("missing value");
+ size_t id = ctx.resolve_parameter(name);
+ if (id == Params::UNDEF) {
+ ctx.fail(make_string("unknown symbol: '%s'", name.c_str()));
} else {
- size_t id = ctx.resolve_parameter(name);
- if (id == Params::UNDEF) {
- ctx.fail(make_string("unknown symbol: '%s'", name.c_str()));
- } else {
- ctx.push_expression(Node_UP(new nodes::Symbol(id)));
- }
+ ctx.push_expression(Node_UP(new nodes::Symbol(id)));
}
}
}
diff --git a/eval/src/vespa/eval/eval/key_gen.cpp b/eval/src/vespa/eval/eval/key_gen.cpp
index 46137b5878c..cd31f92f96a 100644
--- a/eval/src/vespa/eval/eval/key_gen.cpp
+++ b/eval/src/vespa/eval/eval/key_gen.cpp
@@ -42,7 +42,8 @@ struct KeyGen : public NodeVisitor, public NodeTraverser {
void visit(const TensorRename &) override { add_byte(14); } // dimensions should be part of key
void visit(const TensorConcat &) override { add_byte(15); } // dimension should be part of key
void visit(const TensorCreate &) override { add_byte(16); } // type/addr should be part of key
- void visit(const TensorPeek &) override { add_byte(17); } // addr should be part of key
+ void visit(const TensorLambda &) override { add_byte(17); } // type/lambda should be part of key
+ void visit(const TensorPeek &) override { add_byte(18); } // addr should be part of key
void visit(const Add &) override { add_byte(20); }
void visit(const Sub &) override { add_byte(21); }
void visit(const Mul &) override { add_byte(22); }
diff --git a/eval/src/vespa/eval/eval/lazy_params.h b/eval/src/vespa/eval/eval/lazy_params.h
index b6b7753eff9..d75e216571b 100644
--- a/eval/src/vespa/eval/eval/lazy_params.h
+++ b/eval/src/vespa/eval/eval/lazy_params.h
@@ -46,5 +46,14 @@ struct SimpleParams : LazyParams {
const Value &resolve(size_t idx, Stash &stash) const override;
};
+/**
+ * Simple wrapper for cases where you have no parameters.
+ **/
+struct NoParams : LazyParams {
+ const Value &resolve(size_t, Stash &) const override {
+ abort();
+ }
+};
+
} // namespace vespalib::eval
} // namespace vespalib
diff --git a/eval/src/vespa/eval/eval/llvm/compiled_function.cpp b/eval/src/vespa/eval/eval/llvm/compiled_function.cpp
index facfc502111..4101cf10e1f 100644
--- a/eval/src/vespa/eval/eval/llvm/compiled_function.cpp
+++ b/eval/src/vespa/eval/eval/llvm/compiled_function.cpp
@@ -133,6 +133,7 @@ CompiledFunction::detect_issues(const Function &function)
nodes::TensorRename,
nodes::TensorConcat,
nodes::TensorCreate,
+ nodes::TensorLambda,
nodes::TensorPeek>(node))
{
issues.push_back(make_string("unsupported node type: %s",
diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
index cce9838d967..89f1789e97b 100644
--- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
+++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
@@ -485,6 +485,9 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser {
void visit(const TensorCreate &node) override {
make_error(node.num_children());
}
+ void visit(const TensorLambda &node) override {
+ make_error(node.num_children());
+ }
void visit(const TensorPeek &node) override {
make_error(node.num_children());
}
diff --git a/eval/src/vespa/eval/eval/make_tensor_function.cpp b/eval/src/vespa/eval/eval/make_tensor_function.cpp
index 849270b89a7..bbf6cadbac2 100644
--- a/eval/src/vespa/eval/eval/make_tensor_function.cpp
+++ b/eval/src/vespa/eval/eval/make_tensor_function.cpp
@@ -121,6 +121,17 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser {
stack.push_back(tensor_function::create(node.type(), spec, stash));
}
+ void make_lambda(const TensorLambda &node) {
+ InterpretedFunction my_fun(tensor_engine, node.lambda().root(), node.type().dimensions().size(), types);
+ if (node.bindings().empty()) {
+ NoParams no_params;
+ TensorSpec spec = tensor_function::Lambda::create_spec_impl(node.type(), no_params, node.bindings(), my_fun);
+ make_const(node, *stash.create<Value::UP>(tensor_engine.from_spec(spec)));
+ } else {
+ stack.push_back(tensor_function::lambda(node.type(), node.bindings(), std::move(my_fun), stash));
+ }
+ }
+
void make_peek(const TensorPeek &node) {
assert(stack.size() >= node.num_children());
const tensor_function::Node &param = stack[stack.size()-node.num_children()];
@@ -221,6 +232,9 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser {
void visit(const TensorCreate &node) override {
make_create(node);
}
+ void visit(const TensorLambda &node) override {
+ make_lambda(node);
+ }
void visit(const TensorPeek &node) override {
make_peek(node);
}
diff --git a/eval/src/vespa/eval/eval/node_types.cpp b/eval/src/vespa/eval/eval/node_types.cpp
index bf87628e301..37632e35693 100644
--- a/eval/src/vespa/eval/eval/node_types.cpp
+++ b/eval/src/vespa/eval/eval/node_types.cpp
@@ -53,6 +53,13 @@ struct TypeResolver : public NodeVisitor, public NodeTraverser {
return state.type(node);
}
+ void import(const NodeTypes &types) {
+ types.each([&](const Node &node, const ValueType &type)
+ {
+ state.bind(type, node);
+ });
+ }
+
//-------------------------------------------------------------------------
bool check_error(const Node &node) {
@@ -121,6 +128,22 @@ struct TypeResolver : public NodeVisitor, public NodeTraverser {
}
bind(node.type(), node);
}
+ void visit(const TensorLambda &node) override {
+ std::vector<ValueType> arg_types;
+ for (const auto &dim: node.type().dimensions()) {
+ (void) dim;
+ arg_types.push_back(ValueType::double_type());
+ }
+ for (size_t binding: node.bindings()) {
+ arg_types.push_back(param_type(binding));
+ }
+ NodeTypes lambda_types(node.lambda(), arg_types);
+ if (!lambda_types.get_type(node.lambda().root()).is_double()) {
+ return bind(ValueType::error_type(), node);
+ }
+ import(lambda_types);
+ bind(node.type(), node);
+ }
void visit(const TensorPeek &node) override {
const ValueType &param_type = type(node.param());
std::vector<vespalib::string> dimensions;
diff --git a/eval/src/vespa/eval/eval/node_types.h b/eval/src/vespa/eval/eval/node_types.h
index a9ddf371c31..a93f886a2eb 100644
--- a/eval/src/vespa/eval/eval/node_types.h
+++ b/eval/src/vespa/eval/eval/node_types.h
@@ -27,18 +27,19 @@ public:
NodeTypes();
NodeTypes(const Function &function, const std::vector<ValueType> &input_types);
const ValueType &get_type(const nodes::Node &node) const;
- template <typename P>
- bool check_types(const P &pred) const {
+ template <typename F>
+ void each(F &&f) const {
for (const auto &entry: _type_map) {
- if (!pred(entry.second)) {
- return false;
- }
+ f(*entry.first, entry.second);
}
- return (_type_map.size() > 0);
}
bool all_types_are_double() const {
- return check_types([](const ValueType &type)
- { return type.is_double(); });
+ bool all_double = true;
+ each([&all_double](const nodes::Node &, const ValueType &type)
+ {
+ all_double &= type.is_double();
+ });
+ return (all_double && (_type_map.size() > 0));
}
};
diff --git a/eval/src/vespa/eval/eval/node_visitor.h b/eval/src/vespa/eval/eval/node_visitor.h
index 8f9722858b7..d3e066c8f53 100644
--- a/eval/src/vespa/eval/eval/node_visitor.h
+++ b/eval/src/vespa/eval/eval/node_visitor.h
@@ -36,6 +36,7 @@ struct NodeVisitor {
virtual void visit(const nodes::TensorRename &) = 0;
virtual void visit(const nodes::TensorConcat &) = 0;
virtual void visit(const nodes::TensorCreate &) = 0;
+ virtual void visit(const nodes::TensorLambda &) = 0;
virtual void visit(const nodes::TensorPeek &) = 0;
// operator nodes
@@ -106,6 +107,7 @@ struct EmptyNodeVisitor : NodeVisitor {
void visit(const nodes::TensorRename &) override {}
void visit(const nodes::TensorConcat &) override {}
void visit(const nodes::TensorCreate &) override {}
+ void visit(const nodes::TensorLambda &) override {}
void visit(const nodes::TensorPeek &) override {}
void visit(const nodes::Add &) override {}
void visit(const nodes::Sub &) override {}
diff --git a/eval/src/vespa/eval/eval/tensor_function.cpp b/eval/src/vespa/eval/eval/tensor_function.cpp
index 45e8094570e..889738a201d 100644
--- a/eval/src/vespa/eval/eval/tensor_function.cpp
+++ b/eval/src/vespa/eval/eval/tensor_function.cpp
@@ -9,6 +9,7 @@
#include "visit_stuff.h"
#include "string_stuff.h"
#include <vespa/vespalib/objects/objectdumper.h>
+#include <vespa/vespalib/objects/visit.hpp>
#include <vespa/log/log.h>
LOG_SETUP(".eval.eval.tensor_function");
@@ -133,6 +134,13 @@ void op_tensor_create(State &state, uint64_t param) {
state.pop_n_push(i, result);
}
+void op_tensor_lambda(State &state, uint64_t param) {
+ const Lambda &self = unwrap_param<Lambda>(param);
+ TensorSpec spec = self.create_spec(*state.params);
+ const Value &result = *state.stash.create<Value::UP>(state.engine.from_spec(spec));
+ state.stack.emplace_back(result);
+}
+
const Value &extract_single_value(const TensorSpec &spec, const TensorSpec::Address &addr, State &state) {
auto pos = spec.cells().find(addr);
if (pos == spec.cells().end()) {
@@ -193,7 +201,7 @@ void op_tensor_peek(State &state, uint64_t param) {
state.pop_n_push(child_cnt, result);
}
-} // namespace vespalib::eval::tensor_function
+} // namespace vespalib::eval::tensor_function::<unnamed>
//-----------------------------------------------------------------------------
@@ -381,6 +389,72 @@ Create::visit_children(vespalib::ObjectVisitor &visitor) const
//-----------------------------------------------------------------------------
+namespace {
+
+bool step_labels(std::vector<size_t> &labels, const ValueType &type) {
+ for (size_t idx = labels.size(); idx-- > 0; ) {
+ if (++labels[idx] < type.dimensions()[idx].size) {
+ return true;
+ } else {
+ labels[idx] = 0;
+ }
+ }
+ return false;
+}
+
+struct ParamProxy : public LazyParams {
+ const std::vector<size_t> &labels;
+ const LazyParams &params;
+ const std::vector<size_t> &bindings;
+ ParamProxy(const std::vector<size_t> &labels_in, const LazyParams &params_in, const std::vector<size_t> &bindings_in)
+ : labels(labels_in), params(params_in), bindings(bindings_in) {}
+ const Value &resolve(size_t idx, Stash &stash) const override {
+ if (idx < labels.size()) {
+ return stash.create<DoubleValue>(labels[idx]);
+ }
+ return params.resolve(bindings[idx - labels.size()], stash);
+ }
+};
+
+}
+
+TensorSpec
+Lambda::create_spec_impl(const ValueType &type, const LazyParams &params, const std::vector<size_t> &bind, const InterpretedFunction &fun)
+{
+ std::vector<size_t> labels(type.dimensions().size(), 0);
+ ParamProxy param_proxy(labels, params, bind);
+ InterpretedFunction::Context ctx(fun);
+ TensorSpec spec(type.to_spec());
+ do {
+ TensorSpec::Address address;
+ for (size_t i = 0; i < labels.size(); ++i) {
+ address.emplace(type.dimensions()[i].name, labels[i]);
+ }
+ spec.add(std::move(address), fun.eval(ctx, param_proxy).as_double());
+ } while (step_labels(labels, type));
+ return spec;
+}
+
+InterpretedFunction::Instruction
+Lambda::compile_self(Stash &) const
+{
+ return Instruction(op_tensor_lambda, wrap_param<Lambda>(*this));
+}
+
+void
+Lambda::push_children(std::vector<Child::CREF> &) const
+{
+}
+
+void
+Lambda::visit_self(vespalib::ObjectVisitor &visitor) const
+{
+ Super::visit_self(visitor);
+ ::visit(visitor, "bindings", _bindings);
+}
+
+//-----------------------------------------------------------------------------
+
void
Peek::push_children(std::vector<Child::CREF> &children) const
{
@@ -504,6 +578,10 @@ const Node &create(const ValueType &type, const std::map<TensorSpec::Address,Nod
return stash.create<Create>(type, spec);
}
+const Node &lambda(const ValueType &type, const std::vector<size_t> &bindings, InterpretedFunction function, Stash &stash) {
+ return stash.create<Lambda>(type, bindings, std::move(function));
+}
+
const Node &peek(const Node &param, const std::map<vespalib::string, std::variant<TensorSpec::Label, Node::CREF>> &spec, Stash &stash) {
std::vector<vespalib::string> dimensions;
for (const auto &dim_spec: spec) {
diff --git a/eval/src/vespa/eval/eval/tensor_function.h b/eval/src/vespa/eval/eval/tensor_function.h
index 55d27fb74ea..c4e7384abcd 100644
--- a/eval/src/vespa/eval/eval/tensor_function.h
+++ b/eval/src/vespa/eval/eval/tensor_function.h
@@ -30,11 +30,13 @@ class Tensor;
* with information about operation sequencing and intermediate
* results. Each node in the tree describes a single tensor
* operation. This is the intermediate representation of a tensor
- * function.
+ * function. Note that some nodes in the tree are already indirectly
+ * implementation-specific in that they are bound to a specific tensor
+ * engine (typically tensor constants and tensor lambdas).
*
* A tensor function will initially be created based on a Function
- * (expression AST) and associated type-resolving. In this tree, each
- * node will directly represent a single call to the tensor engine
+ * (expression AST) and associated type-resolving. In this tree, most
+ * nodes will directly represent a single call to the tensor engine
* immediate API.
*
* The generic tree will then be optimized (in-place, bottom-up) where
@@ -323,6 +325,25 @@ public:
//-----------------------------------------------------------------------------
+class Lambda : public Node
+{
+ using Super = Node;
+private:
+ std::vector<size_t> _bindings;
+ InterpretedFunction _lambda;
+public:
+ Lambda(const ValueType &result_type_in, const std::vector<size_t> &bindings_in, InterpretedFunction lambda_in)
+ : Node(result_type_in), _bindings(bindings_in), _lambda(std::move(lambda_in)) {}
+ static TensorSpec create_spec_impl(const ValueType &type, const LazyParams &params, const std::vector<size_t> &bind, const InterpretedFunction &fun);
+ TensorSpec create_spec(const LazyParams &params) const { return create_spec_impl(result_type(), params, _bindings, _lambda); }
+ bool result_is_mutable() const override { return true; }
+ InterpretedFunction::Instruction compile_self(Stash &stash) const final override;
+ void push_children(std::vector<Child::CREF> &children) const final override;
+ void visit_self(vespalib::ObjectVisitor &visitor) const override;
+};
+
+//-----------------------------------------------------------------------------
+
class Peek : public Node
{
using Super = Node;
@@ -413,6 +434,7 @@ const Node &join(const Node &lhs, const Node &rhs, join_fun_t function, Stash &s
const Node &merge(const Node &lhs, const Node &rhs, join_fun_t function, Stash &stash);
const Node &concat(const Node &lhs, const Node &rhs, const vespalib::string &dimension, Stash &stash);
const Node &create(const ValueType &type, const std::map<TensorSpec::Address, Node::CREF> &spec, Stash &stash);
+const Node &lambda(const ValueType &type, const std::vector<size_t> &bindings, InterpretedFunction function, Stash &stash);
const Node &peek(const Node &param, const std::map<vespalib::string, std::variant<TensorSpec::Label, Node::CREF>> &spec, Stash &stash);
const Node &rename(const Node &child, const std::vector<vespalib::string> &from, const std::vector<vespalib::string> &to, Stash &stash);
const Node &if_node(const Node &cond, const Node &true_child, const Node &false_child, Stash &stash);
diff --git a/eval/src/vespa/eval/eval/tensor_nodes.cpp b/eval/src/vespa/eval/eval/tensor_nodes.cpp
index 82d108300dd..5cb064ad127 100644
--- a/eval/src/vespa/eval/eval/tensor_nodes.cpp
+++ b/eval/src/vespa/eval/eval/tensor_nodes.cpp
@@ -14,6 +14,7 @@ void TensorReduce::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
void TensorRename::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
void TensorConcat::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
void TensorCreate::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
+void TensorLambda::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
void TensorPeek ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
} // namespace vespalib::eval::nodes
diff --git a/eval/src/vespa/eval/eval/tensor_nodes.h b/eval/src/vespa/eval/eval/tensor_nodes.h
index daba46c1fc5..07be7e77b71 100644
--- a/eval/src/vespa/eval/eval/tensor_nodes.h
+++ b/eval/src/vespa/eval/eval/tensor_nodes.h
@@ -267,6 +267,39 @@ public:
}
};
+class TensorLambda : public Node {
+private:
+ ValueType _type;
+ std::vector<size_t> _bindings;
+ std::shared_ptr<Function const> _lambda;
+public:
+ TensorLambda(ValueType type_in, std::vector<size_t> bindings, std::shared_ptr<Function const> lambda)
+ : _type(std::move(type_in)), _bindings(std::move(bindings)), _lambda(std::move(lambda))
+ {
+ assert(_type.is_dense());
+ assert(_lambda->num_params() == (_type.dimensions().size() + _bindings.size()));
+ }
+ const ValueType &type() const { return _type; }
+ const std::vector<size_t> &bindings() const { return _bindings; }
+ const Function &lambda() const { return *_lambda; }
+ vespalib::string dump(DumpContext &) const override {
+ vespalib::string str = _type.to_spec();
+ vespalib::string expr = _lambda->dump();
+ if (starts_with(expr, "(")) {
+ str += expr;
+ } else {
+ str += "(";
+ str += expr;
+ str += ")";
+ }
+ return str;
+ }
+ void accept(NodeVisitor &visitor) const override;
+ size_t num_children() const override { return 0; }
+ const Node &get_child(size_t) const override { abort(); }
+ void detach_children(NodeHandler &) override {}
+};
+
class TensorPeek : public Node {
public:
struct MyLabel {
diff --git a/eval/src/vespa/eval/eval/tensor_spec.h b/eval/src/vespa/eval/eval/tensor_spec.h
index 25af4c7a93c..22aa47f5ddb 100644
--- a/eval/src/vespa/eval/eval/tensor_spec.h
+++ b/eval/src/vespa/eval/eval/tensor_spec.h
@@ -66,8 +66,8 @@ public:
TensorSpec(const TensorSpec &);
TensorSpec & operator = (const TensorSpec &);
~TensorSpec();
- TensorSpec &add(const Address &address, double value) {
- auto res = _cells.emplace(address, value);
+ TensorSpec &add(Address address, double value) {
+ auto res = _cells.emplace(std::move(address), value);
if (!res.second) {
res.first->second.value += value;
}
diff --git a/eval/src/vespa/eval/eval/test/eval_fixture.cpp b/eval/src/vespa/eval/eval/test/eval_fixture.cpp
index 1e17c8284cb..57db05107bc 100644
--- a/eval/src/vespa/eval/eval/test/eval_fixture.cpp
+++ b/eval/src/vespa/eval/eval/test/eval_fixture.cpp
@@ -10,6 +10,14 @@ using ParamRepo = EvalFixture::ParamRepo;
namespace {
+std::shared_ptr<Function const> verify_function(std::shared_ptr<Function const> fun) {
+ if (fun->has_error()) {
+ fprintf(stderr, "eval_fixture: function parse failed: %s\n", fun->get_error().c_str());
+ }
+ ASSERT_TRUE(!fun->has_error());
+ return std::move(fun);
+}
+
NodeTypes get_types(const Function &function, const ParamRepo &param_repo) {
std::vector<ValueType> param_types;
for (size_t i = 0; i < function.num_params(); ++i) {
@@ -103,7 +111,7 @@ EvalFixture::EvalFixture(const TensorEngine &engine,
bool allow_mutable)
: _engine(engine),
_stash(),
- _function(Function::parse(expr)),
+ _function(verify_function(Function::parse(expr))),
_node_types(get_types(*_function, param_repo)),
_mutable_set(get_mutable(*_function, param_repo)),
_plain_tensor_function(make_tensor_function(_engine, _function->root(), _node_types, _stash)),