diff options
-rw-r--r-- | eval/src/tests/eval/node_types/node_types_test.cpp | 5 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/node_types.cpp | 17 |
2 files changed, 16 insertions, 6 deletions
diff --git a/eval/src/tests/eval/node_types/node_types_test.cpp b/eval/src/tests/eval/node_types/node_types_test.cpp index cabae6307eb..8eaa7a80a81 100644 --- a/eval/src/tests/eval/node_types/node_types_test.cpp +++ b/eval/src/tests/eval/node_types/node_types_test.cpp @@ -223,7 +223,7 @@ TEST("require that merge resolves to the appropriate type") { TEST_DO(verify(strfmt(pattern, "tensor<float>(x[5])", "double"), "error")); } -TEST("require that lambda tensor resolves correct type") { +TEST("require that static tensor lambda resolves correct type") { TEST_DO(verify("tensor(x[5])(1.0)", "tensor(x[5])")); TEST_DO(verify("tensor(x[5],y[10])(1.0)", "tensor(x[5],y[10])")); TEST_DO(verify("tensor(x[5],y[10],z[15])(1.0)", "tensor(x[5],y[10],z[15])")); @@ -242,11 +242,12 @@ TEST("require that tensor create resolves correct type") { TEST_DO(verify("tensor(x[3]):{{x:0}:double,{x:1}:error,{x:2}:double}", "error")); } -TEST("require that tensor lambda resolves correct type") { +TEST("require that dynamic tensor lambda resolves correct type") { TEST_DO(verify("tensor(x[3])(error)", "error")); TEST_DO(verify("tensor(x[3])(double)", "tensor(x[3])")); TEST_DO(verify("tensor<float>(x[3])(double)", "tensor<float>(x[3])")); TEST_DO(verify("tensor(x[3])(tensor(x[2]))", "error")); + TEST_DO(verify("tensor(x[3])(reduce(tensor(x[2])+tensor(x[4]),sum))", "error")); } TEST("require that tensor peek resolves correct type") { diff --git a/eval/src/vespa/eval/eval/node_types.cpp b/eval/src/vespa/eval/eval/node_types.cpp index 924de97c470..5fe441b7a4e 100644 --- a/eval/src/vespa/eval/eval/node_types.cpp +++ b/eval/src/vespa/eval/eval/node_types.cpp @@ -83,7 +83,13 @@ struct TypeResolver : public NodeVisitor, public NodeTraverser { return state.type(node); } - void import(const NodeTypes &types) { + void import_errors(const NodeTypes &types) { + for (const auto &err: types.errors()) { + state.add_error(fmt("[lambda]: %s", err.c_str())); + } + } + + void import_types(const NodeTypes &types) { types.each([&](const Node &node, const ValueType &type) { state.bind(type, node); @@ -189,10 +195,13 @@ struct TypeResolver : public NodeVisitor, public NodeTraverser { arg_types.push_back(param_type(binding)); } NodeTypes lambda_types(node.lambda(), arg_types); - if (!lambda_types.get_type(node.lambda().root()).is_double()) { - return fail(node, "lambda function produces non-double result", false); + const ValueType &lambda_type = lambda_types.get_type(node.lambda().root()); + if (!lambda_type.is_double()) { + import_errors(lambda_types); + return fail(node, fmt("lambda function has non-double result type: %s", + lambda_type.to_spec().c_str()), false); } - import(lambda_types); + import_types(lambda_types); bind(node.type(), node); } void visit(const TensorPeek &node) override { |