diff options
author | Håvard Pettersen <havardpe@oath.com> | 2019-05-02 13:09:37 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2019-05-03 11:22:03 +0000 |
commit | 2bc26be1fd42205b293b723678c5b32cebe9a2a0 (patch) | |
tree | 0a8445c0888b66c48cc53e72eeb40ec7f0c0c942 /eval | |
parent | b915a61385e9e6d3b686fa949a76b763f89dbbe3 (diff) |
add float cases to node type test
also update tensor lambda parsing to support cell type
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/tests/eval/node_types/node_types_test.cpp | 16 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/function.cpp | 9 |
2 files changed, 21 insertions, 4 deletions
diff --git a/eval/src/tests/eval/node_types/node_types_test.cpp b/eval/src/tests/eval/node_types/node_types_test.cpp index d5c0ed995f1..256c7b85f72 100644 --- a/eval/src/tests/eval/node_types/node_types_test.cpp +++ b/eval/src/tests/eval/node_types/node_types_test.cpp @@ -78,6 +78,7 @@ TEST("require that input parameters preserve their type") { TEST_DO(verify("double", "double")); TEST_DO(verify("tensor()", "double")); TEST_DO(verify("tensor(x{},y[10],z[5])", "tensor(x{},y[10],z[5])")); + TEST_DO(verify("tensor<float>(x{},y[10],z[5])", "tensor<float>(x{},y[10],z[5])")); } TEST("require that if resolves to the appropriate type") { @@ -88,6 +89,8 @@ TEST("require that if resolves to the appropriate type") { TEST_DO(verify("if(tensor(x[10]),1,2)", "double")); TEST_DO(verify("if(double,tensor(a{}),tensor(a{}))", "tensor(a{})")); TEST_DO(verify("if(double,tensor(a[2]),tensor(a[2]))", "tensor(a[2])")); + TEST_DO(verify("if(double,tensor<float>(a[2]),tensor<float>(a[2]))", "tensor<float>(a[2])")); + TEST_DO(verify("if(double,tensor(a[2]),tensor<float>(a[2]))", "error")); TEST_DO(verify("if(double,tensor(a[2]),tensor(a[3]))", "error")); TEST_DO(verify("if(double,tensor(a[2]),tensor(a{}))", "error")); TEST_DO(verify("if(double,tensor(a{}),tensor(b{}))", "error")); @@ -105,6 +108,9 @@ TEST("require that reduce resolves correct type") { TEST_DO(verify("reduce(tensor(x{},y{},z{}),sum,y,z,x)", "double")); TEST_DO(verify("reduce(tensor(x{},y{},z{}),sum,w)", "error")); TEST_DO(verify("reduce(tensor(x{}),sum,x)", "double")); + TEST_DO(verify("reduce(tensor<float>(x{},y{},z{}),sum,x,z)", "tensor<float>(y{})")); + TEST_DO(verify("reduce(tensor<float>(x{}),sum,x)", "double")); + TEST_DO(verify("reduce(tensor<float>(x{}),sum)", "double")); } TEST("require that rename resolves correct type") { @@ -119,6 +125,7 @@ TEST("require that rename resolves correct type") { TEST_DO(verify("rename(tensor(x{},y[1],z[5]),(x,y,z),(z,y,x))", "tensor(z{},y[1],x[5])")); TEST_DO(verify("rename(tensor(x{},y[1],z[5]),(x,z),(z,x))", "tensor(z{},y[1],x[5])")); TEST_DO(verify("rename(tensor(x{},y[1],z[5]),(x,y,z),(a,b,c))", "tensor(a{},b[1],c[5])")); + TEST_DO(verify("rename(tensor<float>(x{},y[1],z[5]),(x,y,z),(a,b,c))", "tensor<float>(a{},b[1],c[5])")); } vespalib::string strfmt(const char *pattern, const char *a) { @@ -133,6 +140,7 @@ void verify_op1(const char *pattern) { TEST_DO(verify(strfmt(pattern, "error"), "error")); TEST_DO(verify(strfmt(pattern, "double"), "double")); TEST_DO(verify(strfmt(pattern, "tensor(x{},y[10],z[1])"), "tensor(x{},y[10],z[1])")); + TEST_DO(verify(strfmt(pattern, "tensor<float>(x{},y[10],z[1])"), "tensor<float>(x{},y[10],z[1])")); } void verify_op2(const char *pattern) { @@ -150,6 +158,9 @@ void verify_op2(const char *pattern) { TEST_DO(verify(strfmt(pattern, "tensor(x[3])", "tensor(x[5])"), "error")); TEST_DO(verify(strfmt(pattern, "tensor(x[5])", "tensor(x[3])"), "error")); TEST_DO(verify(strfmt(pattern, "tensor(x{})", "tensor(x[5])"), "error")); + TEST_DO(verify(strfmt(pattern, "tensor<float>(x[5])", "tensor<float>(x[5])"), "tensor<float>(x[5])")); + TEST_DO(verify(strfmt(pattern, "tensor<float>(x[5])", "tensor(x[5])"), "tensor(x[5])")); + TEST_DO(verify(strfmt(pattern, "tensor<float>(x[5])", "double"), "tensor<float>(x[5])")); } TEST("require that various operations resolve appropriate type") { @@ -213,6 +224,8 @@ TEST("require that lambda tensor resolves correct type") { TEST_DO(verify("tensor(x[5])(1.0)", "tensor(x[5])", false)); TEST_DO(verify("tensor(x[5],y[10])(1.0)", "tensor(x[5],y[10])", false)); TEST_DO(verify("tensor(x[5],y[10],z[15])(1.0)", "tensor(x[5],y[10],z[15])", false)); + TEST_DO(verify("tensor<double>(x[5],y[10],z[15])(1.0)", "tensor(x[5],y[10],z[15])", false)); + TEST_DO(verify("tensor<float>(x[5],y[10],z[15])(1.0)", "tensor<float>(x[5],y[10],z[15])", false)); } TEST("require that tensor concat resolves correct type") { @@ -222,6 +235,9 @@ TEST("require that tensor concat resolves correct type") { TEST_DO(verify("concat(tensor(x[2]),tensor(x[3]),y)", "error")); TEST_DO(verify("concat(tensor(x[2]),tensor(x{}),x)", "error")); TEST_DO(verify("concat(tensor(x[2]),tensor(y{}),x)", "tensor(x[3],y{})")); + TEST_DO(verify("concat(tensor<float>(x[2]),tensor<float>(x[3]),x)", "tensor<float>(x[5])")); + TEST_DO(verify("concat(tensor<float>(x[2]),tensor(x[3]),x)", "tensor(x[5])")); + TEST_DO(verify("concat(tensor<float>(x[2]),double,x)", "tensor<float>(x[3])")); } TEST("require that double only expressions can be detected") { diff --git a/eval/src/vespa/eval/eval/function.cpp b/eval/src/vespa/eval/eval/function.cpp index c4f91067260..8b9e440318d 100644 --- a/eval/src/vespa/eval/eval/function.cpp +++ b/eval/src/vespa/eval/eval/function.cpp @@ -558,7 +558,7 @@ void parse_tensor_rename(ParseContext &ctx) { } void parse_tensor_lambda(ParseContext &ctx) { - vespalib::string type_spec("tensor("); + vespalib::string type_spec("tensor"); while(!ctx.eos() && (ctx.get() != ')')) { type_spec.push_back(ctx.get()); ctx.next(); @@ -576,6 +576,7 @@ void parse_tensor_lambda(ParseContext &ctx) { ctx.skip_spaces(); ctx.eat('('); parse_expression(ctx); + ctx.eat(')'); ctx.pop_resolve_context(); Function lambda(ctx.pop_expression(), std::move(param_names)); ctx.push_expression(std::make_unique<nodes::TensorLambda>(std::move(type), std::move(lambda))); @@ -611,8 +612,6 @@ bool try_parse_call(ParseContext &ctx, const vespalib::string &name) { parse_tensor_reduce(ctx); } else if (name == "rename") { parse_tensor_rename(ctx); - } else if (name == "tensor") { - parse_tensor_lambda(ctx); } else if (name == "concat") { parse_tensor_concat(ctx); } else { @@ -634,7 +633,9 @@ size_t parse_symbol(ParseContext &ctx, vespalib::string &name, ParseContext::Inp void parse_symbol_or_call(ParseContext &ctx) { ParseContext::InputMark before_name = ctx.get_input_mark(); vespalib::string name = get_ident(ctx, true); - if (!try_parse_call(ctx, name)) { + if (name == "tensor") { + parse_tensor_lambda(ctx); + } else if (!try_parse_call(ctx, name)) { size_t id = parse_symbol(ctx, name, before_name); if (name.empty()) { ctx.fail("missing value"); |