summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2020-03-20 11:16:56 +0000
committerHåvard Pettersen <havardpe@oath.com>2020-03-20 11:16:56 +0000
commitc3c398c947d1b2f577f5dff3ee6516459d388f47 (patch)
tree1ab69b3c18f576df900b00b800164cbe82858873
parentcc659eb6a33016e412f89b797ea09b10fa4c5f3a (diff)
better tensor lambda type errors
- report actual return type when not double - import type errors from lambda function type resolving
-rw-r--r--eval/src/tests/eval/node_types/node_types_test.cpp5
-rw-r--r--eval/src/vespa/eval/eval/node_types.cpp17
2 files changed, 16 insertions, 6 deletions
diff --git a/eval/src/tests/eval/node_types/node_types_test.cpp b/eval/src/tests/eval/node_types/node_types_test.cpp
index cabae6307eb..8eaa7a80a81 100644
--- a/eval/src/tests/eval/node_types/node_types_test.cpp
+++ b/eval/src/tests/eval/node_types/node_types_test.cpp
@@ -223,7 +223,7 @@ TEST("require that merge resolves to the appropriate type") {
TEST_DO(verify(strfmt(pattern, "tensor<float>(x[5])", "double"), "error"));
}
-TEST("require that lambda tensor resolves correct type") {
+TEST("require that static tensor lambda resolves correct type") {
TEST_DO(verify("tensor(x[5])(1.0)", "tensor(x[5])"));
TEST_DO(verify("tensor(x[5],y[10])(1.0)", "tensor(x[5],y[10])"));
TEST_DO(verify("tensor(x[5],y[10],z[15])(1.0)", "tensor(x[5],y[10],z[15])"));
@@ -242,11 +242,12 @@ TEST("require that tensor create resolves correct type") {
TEST_DO(verify("tensor(x[3]):{{x:0}:double,{x:1}:error,{x:2}:double}", "error"));
}
-TEST("require that tensor lambda resolves correct type") {
+TEST("require that dynamic tensor lambda resolves correct type") {
TEST_DO(verify("tensor(x[3])(error)", "error"));
TEST_DO(verify("tensor(x[3])(double)", "tensor(x[3])"));
TEST_DO(verify("tensor<float>(x[3])(double)", "tensor<float>(x[3])"));
TEST_DO(verify("tensor(x[3])(tensor(x[2]))", "error"));
+ TEST_DO(verify("tensor(x[3])(reduce(tensor(x[2])+tensor(x[4]),sum))", "error"));
}
TEST("require that tensor peek resolves correct type") {
diff --git a/eval/src/vespa/eval/eval/node_types.cpp b/eval/src/vespa/eval/eval/node_types.cpp
index 924de97c470..5fe441b7a4e 100644
--- a/eval/src/vespa/eval/eval/node_types.cpp
+++ b/eval/src/vespa/eval/eval/node_types.cpp
@@ -83,7 +83,13 @@ struct TypeResolver : public NodeVisitor, public NodeTraverser {
return state.type(node);
}
- void import(const NodeTypes &types) {
+ void import_errors(const NodeTypes &types) {
+ for (const auto &err: types.errors()) {
+ state.add_error(fmt("[lambda]: %s", err.c_str()));
+ }
+ }
+
+ void import_types(const NodeTypes &types) {
types.each([&](const Node &node, const ValueType &type)
{
state.bind(type, node);
@@ -189,10 +195,13 @@ struct TypeResolver : public NodeVisitor, public NodeTraverser {
arg_types.push_back(param_type(binding));
}
NodeTypes lambda_types(node.lambda(), arg_types);
- if (!lambda_types.get_type(node.lambda().root()).is_double()) {
- return fail(node, "lambda function produces non-double result", false);
+ const ValueType &lambda_type = lambda_types.get_type(node.lambda().root());
+ if (!lambda_type.is_double()) {
+ import_errors(lambda_types);
+ return fail(node, fmt("lambda function has non-double result type: %s",
+ lambda_type.to_spec().c_str()), false);
}
- import(lambda_types);
+ import_types(lambda_types);
bind(node.type(), node);
}
void visit(const TensorPeek &node) override {