summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2018-01-19 11:42:02 +0000
committerHåvard Pettersen <havardpe@oath.com>2018-01-19 11:42:02 +0000
commit4d14068d491ca12028c512871900750e70b1c964 (patch)
tree8b67a980602bf8778945c5953ec11f7f7c3bde66 /eval
parent90efca9673b36edce527cf02e0b663a9cca624b9 (diff)
improve type resolving for 'if'
also verify inferred types against actual result in conformance test
Diffstat (limited to 'eval')
-rw-r--r--eval/src/apps/tensor_conformance/tensor_conformance.cpp11
-rw-r--r--eval/src/tests/eval/node_types/node_types_test.cpp4
-rw-r--r--eval/src/tests/eval/tensor_function/tensor_function_test.cpp12
-rw-r--r--eval/src/tests/eval/value_type/value_type_test.cpp30
-rw-r--r--eval/src/vespa/eval/eval/node_types.cpp10
-rw-r--r--eval/src/vespa/eval/eval/tensor_function.cpp5
-rw-r--r--eval/src/vespa/eval/eval/value_type.cpp34
-rw-r--r--eval/src/vespa/eval/eval/value_type.h1
8 files changed, 89 insertions, 18 deletions
diff --git a/eval/src/apps/tensor_conformance/tensor_conformance.cpp b/eval/src/apps/tensor_conformance/tensor_conformance.cpp
index 33c303f9574..2d7cf9b5fa0 100644
--- a/eval/src/apps/tensor_conformance/tensor_conformance.cpp
+++ b/eval/src/apps/tensor_conformance/tensor_conformance.cpp
@@ -95,7 +95,11 @@ TensorSpec eval_expr(const Inspector &test, const TensorEngine &engine, bool typ
InterpretedFunction ifun(engine, fun, types);
InterpretedFunction::Context ctx(ifun);
SimpleObjectParams params(param_refs);
- return engine.to_spec(ifun.eval(ctx, params));
+ const Value &result = ifun.eval(ctx, params);
+ if (typed) {
+ ASSERT_EQUAL(result.type(), types.get_type(fun.root()));
+ }
+ return engine.to_spec(result);
}
TensorSpec eval_expr_tf(const Inspector &test, const TensorEngine &engine) {
@@ -110,7 +114,10 @@ TensorSpec eval_expr_tf(const Inspector &test, const TensorEngine &engine) {
SimpleObjectParams params(param_refs);
NodeTypes types = NodeTypes(fun, get_types(param_values));
const auto &tfun = make_tensor_function(engine, fun.root(), types, stash);
- return engine.to_spec(tfun.eval(engine, params, stash));
+ const Value &result = tfun.eval(engine, params, stash);
+ ASSERT_EQUAL(result.type(), tfun.result_type());
+ ASSERT_EQUAL(result.type(), types.get_type(fun.root()));
+ return engine.to_spec(result);
}
//-----------------------------------------------------------------------------
diff --git a/eval/src/tests/eval/node_types/node_types_test.cpp b/eval/src/tests/eval/node_types/node_types_test.cpp
index 97b34f5be3c..ce01d3f78c0 100644
--- a/eval/src/tests/eval/node_types/node_types_test.cpp
+++ b/eval/src/tests/eval/node_types/node_types_test.cpp
@@ -90,6 +90,10 @@ TEST("require that if resolves to the appropriate type") {
TEST_DO(verify("if(tensor,1,2)", "double"));
TEST_DO(verify("if(double,tensor,tensor)", "tensor"));
TEST_DO(verify("if(double,any,any)", "any"));
+ TEST_DO(verify("if(double,tensor(a[2]),tensor(a[2]))", "tensor(a[2])"));
+ TEST_DO(verify("if(double,tensor(a[2]),tensor(a[3]))", "tensor(a[])"));
+ TEST_DO(verify("if(double,tensor(a[2]),tensor(a[]))", "tensor(a[])"));
+ TEST_DO(verify("if(double,tensor(a[2]),tensor(a{}))", "tensor"));
TEST_DO(verify("if(double,tensor(a{}),tensor(a{}))", "tensor(a{})"));
TEST_DO(verify("if(double,tensor(a{}),tensor(b{}))", "tensor"));
TEST_DO(verify("if(double,tensor(a{}),tensor)", "tensor"));
diff --git a/eval/src/tests/eval/tensor_function/tensor_function_test.cpp b/eval/src/tests/eval/tensor_function/tensor_function_test.cpp
index 4e52fd8b47b..fb1ca3d18fe 100644
--- a/eval/src/tests/eval/tensor_function/tensor_function_test.cpp
+++ b/eval/src/tests/eval/tensor_function/tensor_function_test.cpp
@@ -250,14 +250,20 @@ TEST("require that if_node gets expected result type") {
const Node &c = inject(ValueType::from_spec("tensor(x[3])"), 0, stash);
const Node &d = inject(ValueType::from_spec("tensor(x[])"), 0, stash);
const Node &e = inject(ValueType::from_spec("tensor(y[3])"), 0, stash);
+ const Node &f = inject(ValueType::from_spec("double"), 0, stash);
+ const Node &g = inject(ValueType::from_spec("error"), 0, stash);
const Node &if_same = if_node(a, b, b, stash);
const Node &if_similar = if_node(a, b, c, stash);
const Node &if_subtype = if_node(a, b, d, stash);
const Node &if_different = if_node(a, b, e, stash);
+ const Node &if_different_types = if_node(a, b, f, stash);
+ const Node &if_with_error = if_node(a, b, g, stash);
EXPECT_EQUAL(if_same.result_type(), ValueType::from_spec("tensor(x[2])"));
- EXPECT_EQUAL(if_similar.result_type(), ValueType::from_spec("any"));
- EXPECT_EQUAL(if_subtype.result_type(), ValueType::from_spec("any"));
- EXPECT_EQUAL(if_different.result_type(), ValueType::from_spec("any"));
+ EXPECT_EQUAL(if_similar.result_type(), ValueType::from_spec("tensor(x[])"));
+ EXPECT_EQUAL(if_subtype.result_type(), ValueType::from_spec("tensor(x[])"));
+ EXPECT_EQUAL(if_different.result_type(), ValueType::from_spec("tensor"));
+ EXPECT_EQUAL(if_different_types.result_type(), ValueType::from_spec("any"));
+ EXPECT_EQUAL(if_with_error.result_type(), ValueType::from_spec("error"));
}
TEST("require that push_children works") {
diff --git a/eval/src/tests/eval/value_type/value_type_test.cpp b/eval/src/tests/eval/value_type/value_type_test.cpp
index 06a0a4d679a..ffdc601932e 100644
--- a/eval/src/tests/eval/value_type/value_type_test.cpp
+++ b/eval/src/tests/eval/value_type/value_type_test.cpp
@@ -433,4 +433,34 @@ TEST("require that types can be concatenated") {
EXPECT_EQUAL(ValueType::concat(vx_5, vy_7, "z"), cxyz_572);
}
+TEST("require that 'either' gives appropriate type") {
+ ValueType error = ValueType::error_type();
+ ValueType any = ValueType::any_type();
+ ValueType tensor = ValueType::tensor_type({});
+ ValueType scalar = ValueType::double_type();
+ ValueType vx_2 = ValueType::from_spec("tensor(x[2])");
+ ValueType vx_m = ValueType::from_spec("tensor(x{})");
+ ValueType vx_3 = ValueType::from_spec("tensor(x[3])");
+ ValueType vx_any = ValueType::from_spec("tensor(x[])");
+ ValueType vy_2 = ValueType::from_spec("tensor(y[2])");
+ ValueType mxy_22 = ValueType::from_spec("tensor(x[2],y[2])");
+ ValueType mxy_23 = ValueType::from_spec("tensor(x[2],y[3])");
+ ValueType mxy_32 = ValueType::from_spec("tensor(x[3],y[2])");
+ ValueType mxy_any2 = ValueType::from_spec("tensor(x[],y[2])");
+ ValueType mxy_2any = ValueType::from_spec("tensor(x[2],y[])");
+
+ EXPECT_EQUAL(ValueType::either(vx_2, error), error);
+ EXPECT_EQUAL(ValueType::either(error, vx_2), error);
+ EXPECT_EQUAL(ValueType::either(vx_2, vx_2), vx_2);
+ EXPECT_EQUAL(ValueType::either(vx_2, scalar), any);
+ EXPECT_EQUAL(ValueType::either(scalar, vx_2), any);
+ EXPECT_EQUAL(ValueType::either(vx_2, mxy_22), tensor);
+ EXPECT_EQUAL(ValueType::either(tensor, vx_2), tensor);
+ EXPECT_EQUAL(ValueType::either(vx_2, vy_2), tensor);
+ EXPECT_EQUAL(ValueType::either(vx_2, vx_m), tensor);
+ EXPECT_EQUAL(ValueType::either(vx_2, vx_3), vx_any);
+ EXPECT_EQUAL(ValueType::either(mxy_22, mxy_23), mxy_2any);
+ EXPECT_EQUAL(ValueType::either(mxy_32, mxy_22), mxy_any2);
+}
+
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/vespa/eval/eval/node_types.cpp b/eval/src/vespa/eval/eval/node_types.cpp
index 0cbc30667f0..c85c81b3752 100644
--- a/eval/src/vespa/eval/eval/node_types.cpp
+++ b/eval/src/vespa/eval/eval/node_types.cpp
@@ -94,15 +94,7 @@ struct TypeResolver : public NodeVisitor, public NodeTraverser {
void visit(const Neg &node) override { resolve_op1(node); }
void visit(const Not &node) override { resolve_op1(node); }
void visit(const If &node) override {
- ValueType true_type = state.peek(1);
- ValueType false_type = state.peek(0);
- if (true_type == false_type) {
- bind_type(true_type, node);
- } else if (true_type.is_tensor() && false_type.is_tensor()) {
- bind_type(ValueType::tensor_type({}), node);
- } else {
- bind_type(ValueType::any_type(), node);
- }
+ bind_type(ValueType::either(state.peek(1), state.peek(0)), node);
}
void visit(const Error &node) override {
bind_type(ValueType::error_type(), node);
diff --git a/eval/src/vespa/eval/eval/tensor_function.cpp b/eval/src/vespa/eval/eval/tensor_function.cpp
index 1a60dd0e898..8427cc53a16 100644
--- a/eval/src/vespa/eval/eval/tensor_function.cpp
+++ b/eval/src/vespa/eval/eval/tensor_function.cpp
@@ -148,10 +148,7 @@ const Node &rename(const Node &child, const std::vector<vespalib::string> &from,
}
const Node &if_node(const Node &cond, const Node &true_child, const Node &false_child, Stash &stash) {
- ValueType result_type = true_child.result_type();
- if (result_type != false_child.result_type()) {
- result_type = ValueType::any_type();
- }
+ ValueType result_type = ValueType::either(true_child.result_type(), false_child.result_type());
return stash.create<If>(result_type, cond, true_child, false_child);
}
diff --git a/eval/src/vespa/eval/eval/value_type.cpp b/eval/src/vespa/eval/eval/value_type.cpp
index 1c4973a78ca..1d49fe494b7 100644
--- a/eval/src/vespa/eval/eval/value_type.cpp
+++ b/eval/src/vespa/eval/eval/value_type.cpp
@@ -265,6 +265,40 @@ ValueType::concat(const ValueType &lhs, const ValueType &rhs, const vespalib::st
return tensor_type(std::move(result.dimensions));
}
+ValueType
+ValueType::either(const ValueType &one, const ValueType &other)
+{
+ if (one.is_error() || other.is_error()) {
+ return error_type();
+ }
+ if (one == other) {
+ return one;
+ }
+ if (!one.is_tensor() || !other.is_tensor()) {
+ return any_type();
+ }
+ if (one.dimensions().size() != other.dimensions().size()) {
+ return tensor_type({});
+ }
+ std::vector<Dimension> dims;
+ for (size_t i = 0; i < one.dimensions().size(); ++i) {
+ const Dimension &a = one.dimensions()[i];
+ const Dimension &b = other.dimensions()[i];
+ if (a.name != b.name) {
+ return tensor_type({});
+ }
+ if (a.is_mapped() != b.is_mapped()) {
+ return tensor_type({});
+ }
+ if (a.size == b.size) {
+ dims.push_back(a);
+ } else {
+ dims.emplace_back(a.name, 0);
+ }
+ }
+ return tensor_type(std::move(dims));
+}
+
std::ostream &
operator<<(std::ostream &os, const ValueType &type) {
return os << type.to_spec();
diff --git a/eval/src/vespa/eval/eval/value_type.h b/eval/src/vespa/eval/eval/value_type.h
index a4762acd4c0..564d6a6b84e 100644
--- a/eval/src/vespa/eval/eval/value_type.h
+++ b/eval/src/vespa/eval/eval/value_type.h
@@ -87,6 +87,7 @@ public:
vespalib::string to_spec() const;
static ValueType join(const ValueType &lhs, const ValueType &rhs);
static ValueType concat(const ValueType &lhs, const ValueType &rhs, const vespalib::string &dimension);
+ static ValueType either(const ValueType &one, const ValueType &other);
};
std::ostream &operator<<(std::ostream &os, const ValueType &type);