diff options
-rw-r--r-- | eval/src/apps/tensor_conformance/tensor_conformance.cpp | 11 | ||||
-rw-r--r-- | eval/src/tests/eval/node_types/node_types_test.cpp | 4 | ||||
-rw-r--r-- | eval/src/tests/eval/tensor_function/tensor_function_test.cpp | 12 | ||||
-rw-r--r-- | eval/src/tests/eval/value_type/value_type_test.cpp | 30 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/node_types.cpp | 10 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/tensor_function.cpp | 5 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/value_type.cpp | 34 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/value_type.h | 1 |
8 files changed, 89 insertions, 18 deletions
diff --git a/eval/src/apps/tensor_conformance/tensor_conformance.cpp b/eval/src/apps/tensor_conformance/tensor_conformance.cpp index 33c303f9574..2d7cf9b5fa0 100644 --- a/eval/src/apps/tensor_conformance/tensor_conformance.cpp +++ b/eval/src/apps/tensor_conformance/tensor_conformance.cpp @@ -95,7 +95,11 @@ TensorSpec eval_expr(const Inspector &test, const TensorEngine &engine, bool typ InterpretedFunction ifun(engine, fun, types); InterpretedFunction::Context ctx(ifun); SimpleObjectParams params(param_refs); - return engine.to_spec(ifun.eval(ctx, params)); + const Value &result = ifun.eval(ctx, params); + if (typed) { + ASSERT_EQUAL(result.type(), types.get_type(fun.root())); + } + return engine.to_spec(result); } TensorSpec eval_expr_tf(const Inspector &test, const TensorEngine &engine) { @@ -110,7 +114,10 @@ TensorSpec eval_expr_tf(const Inspector &test, const TensorEngine &engine) { SimpleObjectParams params(param_refs); NodeTypes types = NodeTypes(fun, get_types(param_values)); const auto &tfun = make_tensor_function(engine, fun.root(), types, stash); - return engine.to_spec(tfun.eval(engine, params, stash)); + const Value &result = tfun.eval(engine, params, stash); + ASSERT_EQUAL(result.type(), tfun.result_type()); + ASSERT_EQUAL(result.type(), types.get_type(fun.root())); + return engine.to_spec(result); } //----------------------------------------------------------------------------- diff --git a/eval/src/tests/eval/node_types/node_types_test.cpp b/eval/src/tests/eval/node_types/node_types_test.cpp index 97b34f5be3c..ce01d3f78c0 100644 --- a/eval/src/tests/eval/node_types/node_types_test.cpp +++ b/eval/src/tests/eval/node_types/node_types_test.cpp @@ -90,6 +90,10 @@ TEST("require that if resolves to the appropriate type") { TEST_DO(verify("if(tensor,1,2)", "double")); TEST_DO(verify("if(double,tensor,tensor)", "tensor")); TEST_DO(verify("if(double,any,any)", "any")); + TEST_DO(verify("if(double,tensor(a[2]),tensor(a[2]))", "tensor(a[2])")); + TEST_DO(verify("if(double,tensor(a[2]),tensor(a[3]))", "tensor(a[])")); + TEST_DO(verify("if(double,tensor(a[2]),tensor(a[]))", "tensor(a[])")); + TEST_DO(verify("if(double,tensor(a[2]),tensor(a{}))", "tensor")); TEST_DO(verify("if(double,tensor(a{}),tensor(a{}))", "tensor(a{})")); TEST_DO(verify("if(double,tensor(a{}),tensor(b{}))", "tensor")); TEST_DO(verify("if(double,tensor(a{}),tensor)", "tensor")); diff --git a/eval/src/tests/eval/tensor_function/tensor_function_test.cpp b/eval/src/tests/eval/tensor_function/tensor_function_test.cpp index 4e52fd8b47b..fb1ca3d18fe 100644 --- a/eval/src/tests/eval/tensor_function/tensor_function_test.cpp +++ b/eval/src/tests/eval/tensor_function/tensor_function_test.cpp @@ -250,14 +250,20 @@ TEST("require that if_node gets expected result type") { const Node &c = inject(ValueType::from_spec("tensor(x[3])"), 0, stash); const Node &d = inject(ValueType::from_spec("tensor(x[])"), 0, stash); const Node &e = inject(ValueType::from_spec("tensor(y[3])"), 0, stash); + const Node &f = inject(ValueType::from_spec("double"), 0, stash); + const Node &g = inject(ValueType::from_spec("error"), 0, stash); const Node &if_same = if_node(a, b, b, stash); const Node &if_similar = if_node(a, b, c, stash); const Node &if_subtype = if_node(a, b, d, stash); const Node &if_different = if_node(a, b, e, stash); + const Node &if_different_types = if_node(a, b, f, stash); + const Node &if_with_error = if_node(a, b, g, stash); EXPECT_EQUAL(if_same.result_type(), ValueType::from_spec("tensor(x[2])")); - EXPECT_EQUAL(if_similar.result_type(), ValueType::from_spec("any")); - EXPECT_EQUAL(if_subtype.result_type(), ValueType::from_spec("any")); - EXPECT_EQUAL(if_different.result_type(), ValueType::from_spec("any")); + EXPECT_EQUAL(if_similar.result_type(), ValueType::from_spec("tensor(x[])")); + EXPECT_EQUAL(if_subtype.result_type(), ValueType::from_spec("tensor(x[])")); + EXPECT_EQUAL(if_different.result_type(), ValueType::from_spec("tensor")); + EXPECT_EQUAL(if_different_types.result_type(), ValueType::from_spec("any")); + EXPECT_EQUAL(if_with_error.result_type(), ValueType::from_spec("error")); } TEST("require that push_children works") { diff --git a/eval/src/tests/eval/value_type/value_type_test.cpp b/eval/src/tests/eval/value_type/value_type_test.cpp index 06a0a4d679a..ffdc601932e 100644 --- a/eval/src/tests/eval/value_type/value_type_test.cpp +++ b/eval/src/tests/eval/value_type/value_type_test.cpp @@ -433,4 +433,34 @@ TEST("require that types can be concatenated") { EXPECT_EQUAL(ValueType::concat(vx_5, vy_7, "z"), cxyz_572); } +TEST("require that 'either' gives appropriate type") { + ValueType error = ValueType::error_type(); + ValueType any = ValueType::any_type(); + ValueType tensor = ValueType::tensor_type({}); + ValueType scalar = ValueType::double_type(); + ValueType vx_2 = ValueType::from_spec("tensor(x[2])"); + ValueType vx_m = ValueType::from_spec("tensor(x{})"); + ValueType vx_3 = ValueType::from_spec("tensor(x[3])"); + ValueType vx_any = ValueType::from_spec("tensor(x[])"); + ValueType vy_2 = ValueType::from_spec("tensor(y[2])"); + ValueType mxy_22 = ValueType::from_spec("tensor(x[2],y[2])"); + ValueType mxy_23 = ValueType::from_spec("tensor(x[2],y[3])"); + ValueType mxy_32 = ValueType::from_spec("tensor(x[3],y[2])"); + ValueType mxy_any2 = ValueType::from_spec("tensor(x[],y[2])"); + ValueType mxy_2any = ValueType::from_spec("tensor(x[2],y[])"); + + EXPECT_EQUAL(ValueType::either(vx_2, error), error); + EXPECT_EQUAL(ValueType::either(error, vx_2), error); + EXPECT_EQUAL(ValueType::either(vx_2, vx_2), vx_2); + EXPECT_EQUAL(ValueType::either(vx_2, scalar), any); + EXPECT_EQUAL(ValueType::either(scalar, vx_2), any); + EXPECT_EQUAL(ValueType::either(vx_2, mxy_22), tensor); + EXPECT_EQUAL(ValueType::either(tensor, vx_2), tensor); + EXPECT_EQUAL(ValueType::either(vx_2, vy_2), tensor); + EXPECT_EQUAL(ValueType::either(vx_2, vx_m), tensor); + EXPECT_EQUAL(ValueType::either(vx_2, vx_3), vx_any); + EXPECT_EQUAL(ValueType::either(mxy_22, mxy_23), mxy_2any); + EXPECT_EQUAL(ValueType::either(mxy_32, mxy_22), mxy_any2); +} + TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/eval/src/vespa/eval/eval/node_types.cpp b/eval/src/vespa/eval/eval/node_types.cpp index 0cbc30667f0..c85c81b3752 100644 --- a/eval/src/vespa/eval/eval/node_types.cpp +++ b/eval/src/vespa/eval/eval/node_types.cpp @@ -94,15 +94,7 @@ struct TypeResolver : public NodeVisitor, public NodeTraverser { void visit(const Neg &node) override { resolve_op1(node); } void visit(const Not &node) override { resolve_op1(node); } void visit(const If &node) override { - ValueType true_type = state.peek(1); - ValueType false_type = state.peek(0); - if (true_type == false_type) { - bind_type(true_type, node); - } else if (true_type.is_tensor() && false_type.is_tensor()) { - bind_type(ValueType::tensor_type({}), node); - } else { - bind_type(ValueType::any_type(), node); - } + bind_type(ValueType::either(state.peek(1), state.peek(0)), node); } void visit(const Error &node) override { bind_type(ValueType::error_type(), node); diff --git a/eval/src/vespa/eval/eval/tensor_function.cpp b/eval/src/vespa/eval/eval/tensor_function.cpp index 1a60dd0e898..8427cc53a16 100644 --- a/eval/src/vespa/eval/eval/tensor_function.cpp +++ b/eval/src/vespa/eval/eval/tensor_function.cpp @@ -148,10 +148,7 @@ const Node &rename(const Node &child, const std::vector<vespalib::string> &from, } const Node &if_node(const Node &cond, const Node &true_child, const Node &false_child, Stash &stash) { - ValueType result_type = true_child.result_type(); - if (result_type != false_child.result_type()) { - result_type = ValueType::any_type(); - } + ValueType result_type = ValueType::either(true_child.result_type(), false_child.result_type()); return stash.create<If>(result_type, cond, true_child, false_child); } diff --git a/eval/src/vespa/eval/eval/value_type.cpp b/eval/src/vespa/eval/eval/value_type.cpp index 1c4973a78ca..1d49fe494b7 100644 --- a/eval/src/vespa/eval/eval/value_type.cpp +++ b/eval/src/vespa/eval/eval/value_type.cpp @@ -265,6 +265,40 @@ ValueType::concat(const ValueType &lhs, const ValueType &rhs, const vespalib::st return tensor_type(std::move(result.dimensions)); } +ValueType +ValueType::either(const ValueType &one, const ValueType &other) +{ + if (one.is_error() || other.is_error()) { + return error_type(); + } + if (one == other) { + return one; + } + if (!one.is_tensor() || !other.is_tensor()) { + return any_type(); + } + if (one.dimensions().size() != other.dimensions().size()) { + return tensor_type({}); + } + std::vector<Dimension> dims; + for (size_t i = 0; i < one.dimensions().size(); ++i) { + const Dimension &a = one.dimensions()[i]; + const Dimension &b = other.dimensions()[i]; + if (a.name != b.name) { + return tensor_type({}); + } + if (a.is_mapped() != b.is_mapped()) { + return tensor_type({}); + } + if (a.size == b.size) { + dims.push_back(a); + } else { + dims.emplace_back(a.name, 0); + } + } + return tensor_type(std::move(dims)); +} + std::ostream & operator<<(std::ostream &os, const ValueType &type) { return os << type.to_spec(); diff --git a/eval/src/vespa/eval/eval/value_type.h b/eval/src/vespa/eval/eval/value_type.h index a4762acd4c0..564d6a6b84e 100644 --- a/eval/src/vespa/eval/eval/value_type.h +++ b/eval/src/vespa/eval/eval/value_type.h @@ -87,6 +87,7 @@ public: vespalib::string to_spec() const; static ValueType join(const ValueType &lhs, const ValueType &rhs); static ValueType concat(const ValueType &lhs, const ValueType &rhs, const vespalib::string &dimension); + static ValueType either(const ValueType &one, const ValueType &other); }; std::ostream &operator<<(std::ostream &os, const ValueType &type); |