summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2019-05-02 13:09:37 +0000
committerHåvard Pettersen <havardpe@oath.com>2019-05-03 11:22:03 +0000
commit2bc26be1fd42205b293b723678c5b32cebe9a2a0 (patch)
tree0a8445c0888b66c48cc53e72eeb40ec7f0c0c942 /eval
parentb915a61385e9e6d3b686fa949a76b763f89dbbe3 (diff)
add float cases to node type test
also update tensor lambda parsing to support cell type
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/eval/node_types/node_types_test.cpp16
-rw-r--r--eval/src/vespa/eval/eval/function.cpp9
2 files changed, 21 insertions, 4 deletions
diff --git a/eval/src/tests/eval/node_types/node_types_test.cpp b/eval/src/tests/eval/node_types/node_types_test.cpp
index d5c0ed995f1..256c7b85f72 100644
--- a/eval/src/tests/eval/node_types/node_types_test.cpp
+++ b/eval/src/tests/eval/node_types/node_types_test.cpp
@@ -78,6 +78,7 @@ TEST("require that input parameters preserve their type") {
TEST_DO(verify("double", "double"));
TEST_DO(verify("tensor()", "double"));
TEST_DO(verify("tensor(x{},y[10],z[5])", "tensor(x{},y[10],z[5])"));
+ TEST_DO(verify("tensor<float>(x{},y[10],z[5])", "tensor<float>(x{},y[10],z[5])"));
}
TEST("require that if resolves to the appropriate type") {
@@ -88,6 +89,8 @@ TEST("require that if resolves to the appropriate type") {
TEST_DO(verify("if(tensor(x[10]),1,2)", "double"));
TEST_DO(verify("if(double,tensor(a{}),tensor(a{}))", "tensor(a{})"));
TEST_DO(verify("if(double,tensor(a[2]),tensor(a[2]))", "tensor(a[2])"));
+ TEST_DO(verify("if(double,tensor<float>(a[2]),tensor<float>(a[2]))", "tensor<float>(a[2])"));
+ TEST_DO(verify("if(double,tensor(a[2]),tensor<float>(a[2]))", "error"));
TEST_DO(verify("if(double,tensor(a[2]),tensor(a[3]))", "error"));
TEST_DO(verify("if(double,tensor(a[2]),tensor(a{}))", "error"));
TEST_DO(verify("if(double,tensor(a{}),tensor(b{}))", "error"));
@@ -105,6 +108,9 @@ TEST("require that reduce resolves correct type") {
TEST_DO(verify("reduce(tensor(x{},y{},z{}),sum,y,z,x)", "double"));
TEST_DO(verify("reduce(tensor(x{},y{},z{}),sum,w)", "error"));
TEST_DO(verify("reduce(tensor(x{}),sum,x)", "double"));
+ TEST_DO(verify("reduce(tensor<float>(x{},y{},z{}),sum,x,z)", "tensor<float>(y{})"));
+ TEST_DO(verify("reduce(tensor<float>(x{}),sum,x)", "double"));
+ TEST_DO(verify("reduce(tensor<float>(x{}),sum)", "double"));
}
TEST("require that rename resolves correct type") {
@@ -119,6 +125,7 @@ TEST("require that rename resolves correct type") {
TEST_DO(verify("rename(tensor(x{},y[1],z[5]),(x,y,z),(z,y,x))", "tensor(z{},y[1],x[5])"));
TEST_DO(verify("rename(tensor(x{},y[1],z[5]),(x,z),(z,x))", "tensor(z{},y[1],x[5])"));
TEST_DO(verify("rename(tensor(x{},y[1],z[5]),(x,y,z),(a,b,c))", "tensor(a{},b[1],c[5])"));
+ TEST_DO(verify("rename(tensor<float>(x{},y[1],z[5]),(x,y,z),(a,b,c))", "tensor<float>(a{},b[1],c[5])"));
}
vespalib::string strfmt(const char *pattern, const char *a) {
@@ -133,6 +140,7 @@ void verify_op1(const char *pattern) {
TEST_DO(verify(strfmt(pattern, "error"), "error"));
TEST_DO(verify(strfmt(pattern, "double"), "double"));
TEST_DO(verify(strfmt(pattern, "tensor(x{},y[10],z[1])"), "tensor(x{},y[10],z[1])"));
+ TEST_DO(verify(strfmt(pattern, "tensor<float>(x{},y[10],z[1])"), "tensor<float>(x{},y[10],z[1])"));
}
void verify_op2(const char *pattern) {
@@ -150,6 +158,9 @@ void verify_op2(const char *pattern) {
TEST_DO(verify(strfmt(pattern, "tensor(x[3])", "tensor(x[5])"), "error"));
TEST_DO(verify(strfmt(pattern, "tensor(x[5])", "tensor(x[3])"), "error"));
TEST_DO(verify(strfmt(pattern, "tensor(x{})", "tensor(x[5])"), "error"));
+ TEST_DO(verify(strfmt(pattern, "tensor<float>(x[5])", "tensor<float>(x[5])"), "tensor<float>(x[5])"));
+ TEST_DO(verify(strfmt(pattern, "tensor<float>(x[5])", "tensor(x[5])"), "tensor(x[5])"));
+ TEST_DO(verify(strfmt(pattern, "tensor<float>(x[5])", "double"), "tensor<float>(x[5])"));
}
TEST("require that various operations resolve appropriate type") {
@@ -213,6 +224,8 @@ TEST("require that lambda tensor resolves correct type") {
TEST_DO(verify("tensor(x[5])(1.0)", "tensor(x[5])", false));
TEST_DO(verify("tensor(x[5],y[10])(1.0)", "tensor(x[5],y[10])", false));
TEST_DO(verify("tensor(x[5],y[10],z[15])(1.0)", "tensor(x[5],y[10],z[15])", false));
+ TEST_DO(verify("tensor<double>(x[5],y[10],z[15])(1.0)", "tensor(x[5],y[10],z[15])", false));
+ TEST_DO(verify("tensor<float>(x[5],y[10],z[15])(1.0)", "tensor<float>(x[5],y[10],z[15])", false));
}
TEST("require that tensor concat resolves correct type") {
@@ -222,6 +235,9 @@ TEST("require that tensor concat resolves correct type") {
TEST_DO(verify("concat(tensor(x[2]),tensor(x[3]),y)", "error"));
TEST_DO(verify("concat(tensor(x[2]),tensor(x{}),x)", "error"));
TEST_DO(verify("concat(tensor(x[2]),tensor(y{}),x)", "tensor(x[3],y{})"));
+ TEST_DO(verify("concat(tensor<float>(x[2]),tensor<float>(x[3]),x)", "tensor<float>(x[5])"));
+ TEST_DO(verify("concat(tensor<float>(x[2]),tensor(x[3]),x)", "tensor(x[5])"));
+ TEST_DO(verify("concat(tensor<float>(x[2]),double,x)", "tensor<float>(x[3])"));
}
TEST("require that double only expressions can be detected") {
diff --git a/eval/src/vespa/eval/eval/function.cpp b/eval/src/vespa/eval/eval/function.cpp
index c4f91067260..8b9e440318d 100644
--- a/eval/src/vespa/eval/eval/function.cpp
+++ b/eval/src/vespa/eval/eval/function.cpp
@@ -558,7 +558,7 @@ void parse_tensor_rename(ParseContext &ctx) {
}
void parse_tensor_lambda(ParseContext &ctx) {
- vespalib::string type_spec("tensor(");
+ vespalib::string type_spec("tensor");
while(!ctx.eos() && (ctx.get() != ')')) {
type_spec.push_back(ctx.get());
ctx.next();
@@ -576,6 +576,7 @@ void parse_tensor_lambda(ParseContext &ctx) {
ctx.skip_spaces();
ctx.eat('(');
parse_expression(ctx);
+ ctx.eat(')');
ctx.pop_resolve_context();
Function lambda(ctx.pop_expression(), std::move(param_names));
ctx.push_expression(std::make_unique<nodes::TensorLambda>(std::move(type), std::move(lambda)));
@@ -611,8 +612,6 @@ bool try_parse_call(ParseContext &ctx, const vespalib::string &name) {
parse_tensor_reduce(ctx);
} else if (name == "rename") {
parse_tensor_rename(ctx);
- } else if (name == "tensor") {
- parse_tensor_lambda(ctx);
} else if (name == "concat") {
parse_tensor_concat(ctx);
} else {
@@ -634,7 +633,9 @@ size_t parse_symbol(ParseContext &ctx, vespalib::string &name, ParseContext::Inp
void parse_symbol_or_call(ParseContext &ctx) {
ParseContext::InputMark before_name = ctx.get_input_mark();
vespalib::string name = get_ident(ctx, true);
- if (!try_parse_call(ctx, name)) {
+ if (name == "tensor") {
+ parse_tensor_lambda(ctx);
+ } else if (!try_parse_call(ctx, name)) {
size_t id = parse_symbol(ctx, name, before_name);
if (name.empty()) {
ctx.fail("missing value");