summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2020-04-02 10:58:51 +0000
committerHåvard Pettersen <havardpe@oath.com>2020-04-03 09:38:00 +0000
commit5bd3905bb0b4a69243a6c8d3960ed47b454d44f0 (patch)
tree25ab5f6cf48c0cc887deb913023e0e1348c13e03 /eval
parent706cb2d3b2d623318ba9c0a8db0e4355448af65a (diff)
added support for exporting a subset of node types
This is needed to store type information about tensor lambda inner functions until it is needed; we want to delay making it into an interpreted function until after the actual tensor engine implementation gets a chance to come up with a better optimization.
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/eval/node_types/node_types_test.cpp48
-rw-r--r--eval/src/vespa/eval/eval/node_types.cpp32
-rw-r--r--eval/src/vespa/eval/eval/node_types.h3
3 files changed, 78 insertions, 5 deletions
diff --git a/eval/src/tests/eval/node_types/node_types_test.cpp b/eval/src/tests/eval/node_types/node_types_test.cpp
index 8eaa7a80a81..7912ec213bc 100644
--- a/eval/src/tests/eval/node_types/node_types_test.cpp
+++ b/eval/src/tests/eval/node_types/node_types_test.cpp
@@ -18,6 +18,14 @@ struct TypeSpecExtractor : public vespalib::eval::SymbolExtractor {
}
};
+void print_errors(const NodeTypes &types) {
+ if (!types.errors().empty()) {
+ for (const auto &msg: types.errors()) {
+ fprintf(stderr, "type error: %s\n", msg.c_str());
+ }
+ }
+}
+
void verify(const vespalib::string &type_expr, const vespalib::string &type_spec) {
auto function = Function::parse(type_expr, TypeSpecExtractor());
if (!EXPECT_TRUE(!function->has_error())) {
@@ -29,11 +37,7 @@ void verify(const vespalib::string &type_expr, const vespalib::string &type_spec
input_types.push_back(ValueType::from_spec(function->param_name(i)));
}
NodeTypes types(*function, input_types);
- if (!types.errors().empty()) {
- for (const auto &msg: types.errors()) {
- fprintf(stderr, "type error: %s\n", msg.c_str());
- }
- }
+ print_errors(types);
ValueType expected_type = ValueType::from_spec(type_spec);
ValueType actual_type = types.get_type(function->root());
EXPECT_EQUAL(expected_type, actual_type);
@@ -306,4 +310,38 @@ TEST("require that empty type repo works as expected") {
EXPECT_FALSE(types.all_types_are_double());
}
+TEST("require that types for a subtree can be exported") {
+ auto function = Function::parse("(1+2)+3");
+ const auto &root = function->root();
+ ASSERT_EQUAL(root.num_children(), 2u);
+ const auto &n_1_2 = root.get_child(0);
+ const auto &n_3 = root.get_child(1);
+ ASSERT_EQUAL(n_1_2.num_children(), 2u);
+ const auto &n_1 = n_1_2.get_child(0);
+ const auto &n_2 = n_1_2.get_child(1);
+ NodeTypes all_types(*function, {});
+ NodeTypes some_types = all_types.export_types(n_1_2);
+ EXPECT_EQUAL(all_types.errors().size(), 0u);
+ EXPECT_EQUAL(some_types.errors().size(), 0u);
+ for (const auto node: {&root, &n_3}) {
+ EXPECT_TRUE(all_types.get_type(*node).is_double());
+ EXPECT_TRUE(some_types.get_type(*node).is_error());
+ }
+ for (const auto node: {&n_1_2, &n_1, &n_2}) {
+ EXPECT_TRUE(all_types.get_type(*node).is_double());
+ EXPECT_TRUE(some_types.get_type(*node).is_double());
+ }
+}
+
+TEST("require that export_types produces an error for missing types") {
+ auto fun1 = Function::parse("1+2");
+ auto fun2 = Function::parse("1+2");
+ NodeTypes fun1_types(*fun1, {});
+ NodeTypes bad_export = fun1_types.export_types(fun2->root());
+ EXPECT_EQUAL(bad_export.errors().size(), 1u);
+ print_errors(bad_export);
+ EXPECT_TRUE(fun1_types.get_type(fun1->root()).is_double());
+ EXPECT_TRUE(bad_export.get_type(fun2->root()).is_error());
+}
+
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/vespa/eval/eval/node_types.cpp b/eval/src/vespa/eval/eval/node_types.cpp
index 5fe441b7a4e..468b9a58655 100644
--- a/eval/src/vespa/eval/eval/node_types.cpp
+++ b/eval/src/vespa/eval/eval/node_types.cpp
@@ -297,6 +297,26 @@ TypeResolver::TypeResolver(const std::vector<ValueType> &params_in,
TypeResolver::~TypeResolver() {}
+struct TypeExporter : public NodeTraverser {
+ const std::map<const Node *, ValueType> &parent_type_map;
+ std::map<const Node *, ValueType> &exported_type_map;
+ size_t missing_cnt;
+ TypeExporter(const std::map<const Node *, ValueType> &parent_type_map_in,
+ std::map<const Node *, ValueType> &exported_type_map_out)
+ : parent_type_map(parent_type_map_in),
+ exported_type_map(exported_type_map_out),
+ missing_cnt(0) {}
+ bool open(const Node &) override { return true; }
+ void close(const Node &node) override {
+ auto pos = parent_type_map.find(&node);
+ if (pos != parent_type_map.end()) {
+ exported_type_map.emplace(&node, pos->second);
+ } else {
+ ++missing_cnt;
+ }
+ }
+};
+
} // namespace vespalib::eval::nodes::<unnamed>
} // namespace vespalib::eval::nodes
@@ -317,6 +337,18 @@ NodeTypes::NodeTypes(const Function &function, const std::vector<ValueType> &inp
NodeTypes::~NodeTypes() = default;
+NodeTypes
+NodeTypes::export_types(const nodes::Node &root) const
+{
+ NodeTypes exported_types;
+ nodes::TypeExporter exporter(_type_map, exported_types._type_map);
+ root.traverse(exporter);
+ if (exporter.missing_cnt > 0) {
+ exported_types._errors.push_back(fmt("[export]: %zu nodes had missing types", exporter.missing_cnt));
+ }
+ return exported_types;
+}
+
const ValueType &
NodeTypes::get_type(const nodes::Node &node) const
{
diff --git a/eval/src/vespa/eval/eval/node_types.h b/eval/src/vespa/eval/eval/node_types.h
index c072915ffb1..72332564409 100644
--- a/eval/src/vespa/eval/eval/node_types.h
+++ b/eval/src/vespa/eval/eval/node_types.h
@@ -26,9 +26,12 @@ private:
std::vector<vespalib::string> _errors;
public:
NodeTypes();
+ NodeTypes(NodeTypes &&rhs) = default;
+ NodeTypes &operator=(NodeTypes &&rhs) = default;
NodeTypes(const Function &function, const std::vector<ValueType> &input_types);
~NodeTypes();
const std::vector<vespalib::string> &errors() const { return _errors; }
+ NodeTypes export_types(const nodes::Node &root) const;
const ValueType &get_type(const nodes::Node &node) const;
template <typename F>
void each(F &&f) const {