diff options
author | Arne Juul <arnej@verizonmedia.com> | 2020-11-04 15:05:10 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2020-11-04 15:05:25 +0000 |
commit | 8092b21022943489786fe27c02d4a79942382ee7 (patch) | |
tree | 767350caf3e281e41cdb950bda68211de66f41e1 /eval | |
parent | d88bbff8050f59d19b7cad81ae6067ab8b4c1636 (diff) |
fix lambda traversing and extend test
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/tests/eval/tensor_lambda/tensor_lambda_test.cpp | 21 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/node_types.cpp | 1 |
2 files changed, 18 insertions, 4 deletions
diff --git a/eval/src/tests/eval/tensor_lambda/tensor_lambda_test.cpp b/eval/src/tests/eval/tensor_lambda/tensor_lambda_test.cpp index 4dc178d4b35..06749c33843 100644 --- a/eval/src/tests/eval/tensor_lambda/tensor_lambda_test.cpp +++ b/eval/src/tests/eval/tensor_lambda/tensor_lambda_test.cpp @@ -215,16 +215,31 @@ TEST("require that type resolving also include nodes in the inner tensor lambda EXPECT_EQUAL(types.get_type(*symbol).to_spec(), "double"); } +size_t num_exported(const NodeTypes &types) { + size_t cnt = 0; + types.each([&](const auto &, const auto &){++cnt;}); + return cnt; +} + TEST("require that type exporting also include nodes in the inner tensor lambda function") { auto fun = Function::parse("tensor(x[2])(tensor(y[2])((x+y)+a){y:(x)})"); NodeTypes types(*fun, {ValueType::from_spec("double")}); const auto &root = fun->root(); + NodeTypes copy = types.export_types(root); + EXPECT_TRUE(copy.errors().empty()); + EXPECT_EQUAL(num_exported(types), num_exported(copy)); + auto lambda = nodes::as<nodes::TensorLambda>(root); ASSERT_TRUE(lambda != nullptr); - NodeTypes outer = types.export_types(root); - ASSERT_TRUE(outer.errors().empty()); - NodeTypes inner = outer.export_types(lambda->lambda().root()); + NodeTypes outer = copy.export_types(lambda->lambda().root()); + EXPECT_TRUE(outer.errors().empty()); + + auto inner_lambda = nodes::as<nodes::TensorLambda>(lambda->lambda().root().get_child(0)); + ASSERT_TRUE(inner_lambda != nullptr); + NodeTypes inner = outer.export_types(inner_lambda->lambda().root()); EXPECT_TRUE(inner.errors().empty()); + // [x, y, (x+y), a, (x+y)+a] are the 5 nodes: + EXPECT_EQUAL(num_exported(inner), 5u); } TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/eval/src/vespa/eval/eval/node_types.cpp b/eval/src/vespa/eval/eval/node_types.cpp index 0fd4830a7e9..9569518417d 100644 --- a/eval/src/vespa/eval/eval/node_types.cpp +++ b/eval/src/vespa/eval/eval/node_types.cpp @@ -310,7 +310,6 @@ struct TypeExporter : public NodeTraverser { bool open(const Node &node) override { if (auto lambda = as<TensorLambda>(node)) { lambda->lambda().root().traverse(*this); - return false; } return true; } |