summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2020-11-04 15:05:10 +0000
committerArne Juul <arnej@verizonmedia.com>2020-11-04 15:05:25 +0000
commit8092b21022943489786fe27c02d4a79942382ee7 (patch)
tree767350caf3e281e41cdb950bda68211de66f41e1 /eval
parentd88bbff8050f59d19b7cad81ae6067ab8b4c1636 (diff)
fix lambda traversing and extend test
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/eval/tensor_lambda/tensor_lambda_test.cpp21
-rw-r--r--eval/src/vespa/eval/eval/node_types.cpp1
2 files changed, 18 insertions, 4 deletions
diff --git a/eval/src/tests/eval/tensor_lambda/tensor_lambda_test.cpp b/eval/src/tests/eval/tensor_lambda/tensor_lambda_test.cpp
index 4dc178d4b35..06749c33843 100644
--- a/eval/src/tests/eval/tensor_lambda/tensor_lambda_test.cpp
+++ b/eval/src/tests/eval/tensor_lambda/tensor_lambda_test.cpp
@@ -215,16 +215,31 @@ TEST("require that type resolving also include nodes in the inner tensor lambda
EXPECT_EQUAL(types.get_type(*symbol).to_spec(), "double");
}
+size_t num_exported(const NodeTypes &types) {
+ size_t cnt = 0;
+ types.each([&](const auto &, const auto &){++cnt;});
+ return cnt;
+}
+
TEST("require that type exporting also include nodes in the inner tensor lambda function") {
auto fun = Function::parse("tensor(x[2])(tensor(y[2])((x+y)+a){y:(x)})");
NodeTypes types(*fun, {ValueType::from_spec("double")});
const auto &root = fun->root();
+ NodeTypes copy = types.export_types(root);
+ EXPECT_TRUE(copy.errors().empty());
+ EXPECT_EQUAL(num_exported(types), num_exported(copy));
+
auto lambda = nodes::as<nodes::TensorLambda>(root);
ASSERT_TRUE(lambda != nullptr);
- NodeTypes outer = types.export_types(root);
- ASSERT_TRUE(outer.errors().empty());
- NodeTypes inner = outer.export_types(lambda->lambda().root());
+ NodeTypes outer = copy.export_types(lambda->lambda().root());
+ EXPECT_TRUE(outer.errors().empty());
+
+ auto inner_lambda = nodes::as<nodes::TensorLambda>(lambda->lambda().root().get_child(0));
+ ASSERT_TRUE(inner_lambda != nullptr);
+ NodeTypes inner = outer.export_types(inner_lambda->lambda().root());
EXPECT_TRUE(inner.errors().empty());
+ // [x, y, (x+y), a, (x+y)+a] are the 5 nodes:
+ EXPECT_EQUAL(num_exported(inner), 5u);
}
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/vespa/eval/eval/node_types.cpp b/eval/src/vespa/eval/eval/node_types.cpp
index 0fd4830a7e9..9569518417d 100644
--- a/eval/src/vespa/eval/eval/node_types.cpp
+++ b/eval/src/vespa/eval/eval/node_types.cpp
@@ -310,7 +310,6 @@ struct TypeExporter : public NodeTraverser {
bool open(const Node &node) override {
if (auto lambda = as<TensorLambda>(node)) {
lambda->lambda().root().traverse(*this);
- return false;
}
return true;
}