diff options
author | Håvard Pettersen <havardpe@oath.com> | 2020-04-02 10:58:51 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2020-04-03 09:38:00 +0000 |
commit | 5bd3905bb0b4a69243a6c8d3960ed47b454d44f0 (patch) | |
tree | 25ab5f6cf48c0cc887deb913023e0e1348c13e03 /eval | |
parent | 706cb2d3b2d623318ba9c0a8db0e4355448af65a (diff) |
added support for exporting a subset of node types
This is needed to store type information about tensor lambda inner
functions until it is needed; we want to delay making it into an
interpreted function until after the actual tensor engine implementation
gets a chance to come up with a better optimization.
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/tests/eval/node_types/node_types_test.cpp | 48 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/node_types.cpp | 32 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/node_types.h | 3 |
3 files changed, 78 insertions, 5 deletions
diff --git a/eval/src/tests/eval/node_types/node_types_test.cpp b/eval/src/tests/eval/node_types/node_types_test.cpp index 8eaa7a80a81..7912ec213bc 100644 --- a/eval/src/tests/eval/node_types/node_types_test.cpp +++ b/eval/src/tests/eval/node_types/node_types_test.cpp @@ -18,6 +18,14 @@ struct TypeSpecExtractor : public vespalib::eval::SymbolExtractor { } }; +void print_errors(const NodeTypes &types) { + if (!types.errors().empty()) { + for (const auto &msg: types.errors()) { + fprintf(stderr, "type error: %s\n", msg.c_str()); + } + } +} + void verify(const vespalib::string &type_expr, const vespalib::string &type_spec) { auto function = Function::parse(type_expr, TypeSpecExtractor()); if (!EXPECT_TRUE(!function->has_error())) { @@ -29,11 +37,7 @@ void verify(const vespalib::string &type_expr, const vespalib::string &type_spec input_types.push_back(ValueType::from_spec(function->param_name(i))); } NodeTypes types(*function, input_types); - if (!types.errors().empty()) { - for (const auto &msg: types.errors()) { - fprintf(stderr, "type error: %s\n", msg.c_str()); - } - } + print_errors(types); ValueType expected_type = ValueType::from_spec(type_spec); ValueType actual_type = types.get_type(function->root()); EXPECT_EQUAL(expected_type, actual_type); @@ -306,4 +310,38 @@ TEST("require that empty type repo works as expected") { EXPECT_FALSE(types.all_types_are_double()); } +TEST("require that types for a subtree can be exported") { + auto function = Function::parse("(1+2)+3"); + const auto &root = function->root(); + ASSERT_EQUAL(root.num_children(), 2u); + const auto &n_1_2 = root.get_child(0); + const auto &n_3 = root.get_child(1); + ASSERT_EQUAL(n_1_2.num_children(), 2u); + const auto &n_1 = n_1_2.get_child(0); + const auto &n_2 = n_1_2.get_child(1); + NodeTypes all_types(*function, {}); + NodeTypes some_types = all_types.export_types(n_1_2); + EXPECT_EQUAL(all_types.errors().size(), 0u); + EXPECT_EQUAL(some_types.errors().size(), 0u); + for (const auto node: {&root, &n_3}) { + EXPECT_TRUE(all_types.get_type(*node).is_double()); + EXPECT_TRUE(some_types.get_type(*node).is_error()); + } + for (const auto node: {&n_1_2, &n_1, &n_2}) { + EXPECT_TRUE(all_types.get_type(*node).is_double()); + EXPECT_TRUE(some_types.get_type(*node).is_double()); + } +} + +TEST("require that export_types produces an error for missing types") { + auto fun1 = Function::parse("1+2"); + auto fun2 = Function::parse("1+2"); + NodeTypes fun1_types(*fun1, {}); + NodeTypes bad_export = fun1_types.export_types(fun2->root()); + EXPECT_EQUAL(bad_export.errors().size(), 1u); + print_errors(bad_export); + EXPECT_TRUE(fun1_types.get_type(fun1->root()).is_double()); + EXPECT_TRUE(bad_export.get_type(fun2->root()).is_error()); +} + TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/eval/src/vespa/eval/eval/node_types.cpp b/eval/src/vespa/eval/eval/node_types.cpp index 5fe441b7a4e..468b9a58655 100644 --- a/eval/src/vespa/eval/eval/node_types.cpp +++ b/eval/src/vespa/eval/eval/node_types.cpp @@ -297,6 +297,26 @@ TypeResolver::TypeResolver(const std::vector<ValueType> ¶ms_in, TypeResolver::~TypeResolver() {} +struct TypeExporter : public NodeTraverser { + const std::map<const Node *, ValueType> &parent_type_map; + std::map<const Node *, ValueType> &exported_type_map; + size_t missing_cnt; + TypeExporter(const std::map<const Node *, ValueType> &parent_type_map_in, + std::map<const Node *, ValueType> &exported_type_map_out) + : parent_type_map(parent_type_map_in), + exported_type_map(exported_type_map_out), + missing_cnt(0) {} + bool open(const Node &) override { return true; } + void close(const Node &node) override { + auto pos = parent_type_map.find(&node); + if (pos != parent_type_map.end()) { + exported_type_map.emplace(&node, pos->second); + } else { + ++missing_cnt; + } + } +}; + } // namespace vespalib::eval::nodes::<unnamed> } // namespace vespalib::eval::nodes @@ -317,6 +337,18 @@ NodeTypes::NodeTypes(const Function &function, const std::vector<ValueType> &inp NodeTypes::~NodeTypes() = default; +NodeTypes +NodeTypes::export_types(const nodes::Node &root) const +{ + NodeTypes exported_types; + nodes::TypeExporter exporter(_type_map, exported_types._type_map); + root.traverse(exporter); + if (exporter.missing_cnt > 0) { + exported_types._errors.push_back(fmt("[export]: %zu nodes had missing types", exporter.missing_cnt)); + } + return exported_types; +} + const ValueType & NodeTypes::get_type(const nodes::Node &node) const { diff --git a/eval/src/vespa/eval/eval/node_types.h b/eval/src/vespa/eval/eval/node_types.h index c072915ffb1..72332564409 100644 --- a/eval/src/vespa/eval/eval/node_types.h +++ b/eval/src/vespa/eval/eval/node_types.h @@ -26,9 +26,12 @@ private: std::vector<vespalib::string> _errors; public: NodeTypes(); + NodeTypes(NodeTypes &&rhs) = default; + NodeTypes &operator=(NodeTypes &&rhs) = default; NodeTypes(const Function &function, const std::vector<ValueType> &input_types); ~NodeTypes(); const std::vector<vespalib::string> &errors() const { return _errors; } + NodeTypes export_types(const nodes::Node &root) const; const ValueType &get_type(const nodes::Node &node) const; template <typename F> void each(F &&f) const { |