diff options
author | Håvard Pettersen <havardpe@oath.com> | 2017-11-03 11:44:25 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2017-11-03 12:35:31 +0000 |
commit | 228a4f089d2431bf2012982bce9e093f2df2dead (patch) | |
tree | 01119a93d453585edf6c88017c24abff534b0b21 /eval/src | |
parent | 5a69cb546ad5661001c89eecbf13c8b41b57019c (diff) |
handle 'in' operator as custom (tensor) map operation
free arrays no longer allowed
restrict set members to be numbers or strings
auto-unbox negative numbers in AST
Diffstat (limited to 'eval/src')
-rw-r--r-- | eval/src/tests/eval/compiled_function/compiled_function_test.cpp | 15 | ||||
-rw-r--r-- | eval/src/tests/eval/function/function_test.cpp | 101 | ||||
-rw-r--r-- | eval/src/tests/eval/gbdt/gbdt_test.cpp | 32 | ||||
-rw-r--r-- | eval/src/tests/eval/node_types/node_types_test.cpp | 21 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/basic_nodes.cpp | 4 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/basic_nodes.h | 47 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/function.cpp | 64 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/gbdt.cpp | 6 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/interpreted_function.cpp | 50 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/key_gen.cpp | 8 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp | 64 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/node_types.cpp | 7 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/node_visitor.h | 6 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/operator_nodes.cpp | 13 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/operator_nodes.h | 3 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/test/eval_spec.cpp | 73 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/vm_forest.cpp | 15 |
17 files changed, 224 insertions, 305 deletions
diff --git a/eval/src/tests/eval/compiled_function/compiled_function_test.cpp b/eval/src/tests/eval/compiled_function/compiled_function_test.cpp index b887c6e45f9..0e9806d5381 100644 --- a/eval/src/tests/eval/compiled_function/compiled_function_test.cpp +++ b/eval/src/tests/eval/compiled_function/compiled_function_test.cpp @@ -128,16 +128,13 @@ TEST_FF("require that compiled evaluation passes all conformance tests", MyEvalT //----------------------------------------------------------------------------- TEST("require that large (plugin) set membership checks work") { - nodes::Array my_set; + auto my_in = std::make_unique<nodes::In>(std::make_unique<nodes::Symbol>(0)); for(size_t i = 1; i <= 100; ++i) { - my_set.add(nodes::Node_UP(new nodes::Number(i))); + my_in->add_entry(std::make_unique<nodes::Number>(i)); } - nodes::DumpContext dump_ctx({}); - vespalib::string expr = vespalib::make_string("if(a in %s,1,0)", - my_set.dump(dump_ctx).c_str()); - // fprintf(stderr, "expression: %s\n", expr.c_str()); - CompiledFunction cf(Function::parse(expr), PassParams::SEPARATE); - CompiledFunction arr_cf(Function::parse(expr), PassParams::ARRAY); + Function my_fun(std::move(my_in), {"a"}); + CompiledFunction cf(my_fun, PassParams::SEPARATE); + CompiledFunction arr_cf(my_fun, PassParams::ARRAY); auto fun = cf.get_function<1>(); auto arr_fun = arr_cf.get_function(); for (double value = 0.5; value <= 100.5; value += 0.5) { @@ -146,7 +143,7 @@ TEST("require that large (plugin) set membership checks work") { EXPECT_EQUAL(1.0, arr_fun(&value)); } else { EXPECT_EQUAL(0.0, fun(value)); - EXPECT_EQUAL(0.0, arr_fun(&value)); + EXPECT_EQUAL(0.0, arr_fun(&value)); } } } diff --git a/eval/src/tests/eval/function/function_test.cpp b/eval/src/tests/eval/function/function_test.cpp index 75a6df41b50..6c3839b6cc9 100644 --- a/eval/src/tests/eval/function/function_test.cpp +++ b/eval/src/tests/eval/function/function_test.cpp @@ -81,6 +81,11 @@ bool verify_string(const vespalib::string &str, const vespalib::string &expr) { return ok; } +void verify_error(const vespalib::string &expr, const vespalib::string &expected_error) { + Function function = Function::parse(params, expr); + EXPECT_TRUE(function.has_error()); + EXPECT_EQUAL(expected_error, function.get_error()); +} TEST("require that scientific numbers can be parsed") { EXPECT_EQUAL(1.0, as_number(Function::parse(params, "1"))); @@ -163,18 +168,16 @@ TEST("require that strings are parsed and dumped correctly") { } } -TEST("require that arrays can be parsed") { - EXPECT_EQUAL("[]", Function::parse(params, "[]").dump()); - EXPECT_EQUAL("[1,2,3]", Function::parse(params, "[1,2,3]").dump()); - EXPECT_EQUAL("[1,2,3]", Function::parse(params, "[ 1 , 2 , 3 ]").dump()); - EXPECT_EQUAL("[[x],[x,y],[1,2,[z,w]]]", Function::parse(params, "[[x],[x,y],[1,2,[z,w]]]").dump()); - EXPECT_EQUAL("[(x+1),(y-[3,7]),z,[]]", Function::parse(params, "[x+1,y-[3,7],z,[]]").dump()); +TEST("require that free arrays cannot be parsed") { + verify_error("[1,2,3]", "[]...[missing value]...[[1,2,3]]"); } TEST("require that negative values can be parsed") { - EXPECT_EQUAL("(-1)", Function::parse(params, "-1").dump()); - EXPECT_EQUAL("(-2.5)", Function::parse(params, "-2.5").dump()); - EXPECT_EQUAL("(-100)", Function::parse(params, "-100").dump()); + EXPECT_EQUAL("-1", Function::parse(params, "-1").dump()); + EXPECT_EQUAL("1", Function::parse(params, "--1").dump()); + EXPECT_EQUAL("-1", Function::parse(params, " ( - ( - ( - ( (1) ) ) ) )").dump()); + EXPECT_EQUAL("-2.5", Function::parse(params, "-2.5").dump()); + EXPECT_EQUAL("-100", Function::parse(params, "-100").dump()); } TEST("require that negative symbols can be parsed") { @@ -206,7 +209,7 @@ TEST("require that operators have appropriate binding order") { verify_operator_binding_order({ { Operator::Order::RIGHT, { "^" } }, { Operator::Order::LEFT, { "*", "/", "%" } }, { Operator::Order::LEFT, { "+", "-" } }, - { Operator::Order::LEFT, { "==", "!=", "~=", "<", "<=", ">", ">=", "in" } }, + { Operator::Order::LEFT, { "==", "!=", "~=", "<", "<=", ">", ">=" } }, { Operator::Order::LEFT, { "&&" } }, { Operator::Order::LEFT, { "||" } } }); } @@ -248,10 +251,31 @@ TEST("require that operators can not bind out of parenthesis") { } TEST("require that set membership constructs can be parsed") { - EXPECT_EQUAL("(x in [y,z,w])", Function::parse(params, "x in [y,z,w]").dump()); - EXPECT_EQUAL("(x in [y,z,w])", Function::parse(params, "x in[y,z,w]").dump()); - EXPECT_EQUAL("(x in [y,z,w])", Function::parse(params, "(x)in[y,z,w]").dump()); - EXPECT_EQUAL("((x+1) in [y,z,(w-1)])", Function::parse(params, "(x+1)in[y,z,(w-1)]").dump()); + EXPECT_EQUAL("(x in [1,2,3])", Function::parse(params, "x in [1,2,3]").dump()); + EXPECT_EQUAL("(x in [1,2,3])", Function::parse(params, "x in [ 1 , 2 , 3 ] ").dump()); + EXPECT_EQUAL("(x in [-1,-2,-3])", Function::parse(params, "x in [-1,-2,-3]").dump()); + EXPECT_EQUAL("(x in [-1,-2,-3])", Function::parse(params, "x in [ - 1 , - 2 , - 3 ]").dump()); + EXPECT_EQUAL("(x in [1,2,3])", Function::parse(params, "x in[1,2,3]").dump()); + EXPECT_EQUAL("(x in [1,2,3])", Function::parse(params, "(x)in[1,2,3]").dump()); + EXPECT_EQUAL("(x in [\"a\",2,\"c\"])", Function::parse(params, "x in [\"a\",2,\"c\"]").dump()); +} + +TEST("require that set membership entries must be array of strings/numbers") { + verify_error("x in 1", "[x in ]...[expected '[', but got '1']...[1]"); + verify_error("x in ([1])", "[x in ]...[expected '[', but got '(']...[([1])]"); + verify_error("x in [y]", "[x in [y]...[invalid entry for 'in' operator]...[]]"); + verify_error("x in [!1]", "[x in [!1]...[invalid entry for 'in' operator]...[]]"); + verify_error("x in [1+2]", "[x in [1]...[expected ',', but got '+']...[+2]]"); + verify_error("x in [-\"foo\"]", "[x in [-\"foo\"]...[invalid entry for 'in' operator]...[]]"); +} + +TEST("require that set membership binds to the next value") { + EXPECT_EQUAL("((x in [1,2,3])^2)", Function::parse(params, "x in [1,2,3]^2").dump()); +} + +TEST("require that set membership binds to the left with appropriate precedence") { + EXPECT_EQUAL("((x<y) in [1,2,3])", Function::parse(params, "x < y in [1,2,3]").dump()); + EXPECT_EQUAL("(x&&(y in [1,2,3]))", Function::parse(params, "x && y in [1,2,3]").dump()); } TEST("require that function calls can be parsed") { @@ -309,22 +333,12 @@ TEST("require that leaf nodes have no children") { EXPECT_EQUAL(0u, Function::parse("\"abc\"").root().num_children()); } -TEST("require that Array children can be accessed") { - Function f = Function::parse("[1,2,3]"); - const Node &root = f.root(); - EXPECT_TRUE(!root.is_leaf()); - ASSERT_EQUAL(3u, root.num_children()); - EXPECT_EQUAL(1.0, root.get_child(0).get_const_value()); - EXPECT_EQUAL(2.0, root.get_child(1).get_const_value()); - EXPECT_EQUAL(3.0, root.get_child(2).get_const_value()); -} - TEST("require that Neg child can be accessed") { - Function f = Function::parse("-1"); + Function f = Function::parse("-x"); const Node &root = f.root(); EXPECT_TRUE(!root.is_leaf()); ASSERT_EQUAL(1u, root.num_children()); - EXPECT_EQUAL(1.0, root.get_child(0).get_const_value()); + EXPECT_TRUE(root.get_child(0).is_param()); } TEST("require that Not child can be accessed") { @@ -386,7 +400,7 @@ TEST("require that children can be detached") { EXPECT_EQUAL(1u, detach_from_root("-a")); EXPECT_EQUAL(1u, detach_from_root("!a")); EXPECT_EQUAL(3u, detach_from_root("if(1,2,3)")); - EXPECT_EQUAL(5u, detach_from_root("[1,2,3,4,5]")); + EXPECT_EQUAL(1u, detach_from_root("a in [1,2,3,4,5]")); EXPECT_EQUAL(2u, detach_from_root("a+b")); EXPECT_EQUAL(1u, detach_from_root("isNan(a)")); EXPECT_EQUAL(2u, detach_from_root("max(a,b)")); @@ -456,7 +470,7 @@ TEST("require that traversal works as expected") { EXPECT_TRUE(verify_expression_traversal("1")); EXPECT_TRUE(verify_expression_traversal("1+2")); EXPECT_TRUE(verify_expression_traversal("1+2*3-4/5")); - EXPECT_TRUE(verify_expression_traversal("if(x,1+2*3,[a,b,c]/5)")); + EXPECT_TRUE(verify_expression_traversal("if(x,1+2*3,if(a,b,c)/5)")); } //----------------------------------------------------------------------------- @@ -492,14 +506,6 @@ TEST("require that string is const") { EXPECT_TRUE(Function::parse("\"x\"").root().is_const()); } -TEST("require that array is const if all elements are const") { - EXPECT_TRUE(Function::parse("[1,2,3]").root().is_const()); - EXPECT_TRUE(!Function::parse("[x,2,3]").root().is_const()); - EXPECT_TRUE(!Function::parse("[1,y,3]").root().is_const()); - EXPECT_TRUE(!Function::parse("[1,2,z]").root().is_const()); - EXPECT_TRUE(!Function::parse("[x,y,z]").root().is_const()); -} - TEST("require that neg is const if sub-expression is const") { EXPECT_TRUE(Function::parse("-123").root().is_const()); EXPECT_TRUE(!Function::parse("-x").root().is_const()); @@ -517,11 +523,11 @@ TEST("require that operators are cost if both children are const") { EXPECT_TRUE(Function::parse("1+2").root().is_const()); } -TEST("require that set membership is const only if array elements are const") { +TEST("require that set membership is never tagged as const (NB: avoids jit recursion)") { EXPECT_TRUE(!Function::parse("x in [x,y,z]").root().is_const()); EXPECT_TRUE(!Function::parse("1 in [x,y,z]").root().is_const()); EXPECT_TRUE(!Function::parse("1 in [1,y,z]").root().is_const()); - EXPECT_TRUE(Function::parse("1 in [1,2,3]").root().is_const()); + EXPECT_TRUE(!Function::parse("1 in [1,2,3]").root().is_const()); } TEST("require that calls are cost if all parameters are const") { @@ -554,10 +560,8 @@ TEST("require that feature in set of constants is tree if children are trees or EXPECT_TRUE(Function::parse("if (foo in [1, 2], if(bar < 3, 4, 5), 6)").root().is_tree()); EXPECT_TRUE(Function::parse("if (foo in [1, 2], if(bar < 3, 4, 5), if(baz < 6, 7, 8))").root().is_tree()); EXPECT_TRUE(Function::parse("if (foo in [1, 2], 3, if(baz < 4, 5, 6))").root().is_tree()); - EXPECT_TRUE(Function::parse("if (foo in [min(1,2), max(1,2)], 3, 4)").root().is_tree()); + EXPECT_TRUE(Function::parse("if (foo in [1, 2], min(1,3), max(1,4))").root().is_tree()); EXPECT_TRUE(!Function::parse("if (1 in [1, 2], 3, 4)").root().is_tree()); - EXPECT_TRUE(!Function::parse("if (1 in [foo, 2], 3, 4)").root().is_tree()); - EXPECT_TRUE(!Function::parse("if (foo in [bar, 2], 3, 4)").root().is_tree()); } TEST("require that sums of trees and forests are forests") { @@ -671,14 +675,17 @@ TEST("require that unknown function that is valid parameter works as expected wi EXPECT_EQUAL("[z(x)]...[unknown symbol: 'z(x)']...[+y]", Function::parse(params, "z(x)+y", MySymbolExtractor({'(', ')'})).dump()); } -//----------------------------------------------------------------------------- - -void verify_error(const vespalib::string &expr, const vespalib::string &expected_error) { - Function function = Function::parse(params, expr); - EXPECT_TRUE(function.has_error()); - EXPECT_EQUAL(expected_error, function.get_error()); +TEST("require that custom symbol extractor is not invoked for known function call") { + MySymbolExtractor extractor; + EXPECT_EQUAL(extractor.invoke_count, 0u); + EXPECT_EQUAL("[bogus]...[unknown symbol: 'bogus']...[(1,2)]", Function::parse(params, "bogus(1,2)", extractor).dump()); + EXPECT_EQUAL(extractor.invoke_count, 1u); + EXPECT_EQUAL("max(1,2)", Function::parse(params, "max(1,2)", extractor).dump()); + EXPECT_EQUAL(extractor.invoke_count, 1u); } +//----------------------------------------------------------------------------- + TEST("require that valid function does not report parse error") { Function function = Function::parse(params, "x + y"); EXPECT_TRUE(!function.has_error()); diff --git a/eval/src/tests/eval/gbdt/gbdt_test.cpp b/eval/src/tests/eval/gbdt/gbdt_test.cpp index f4f5970cabe..20969a1e3b4 100644 --- a/eval/src/tests/eval/gbdt/gbdt_test.cpp +++ b/eval/src/tests/eval/gbdt/gbdt_test.cpp @@ -31,14 +31,14 @@ TEST("require that tree stats can be calculated") { EXPECT_EQUAL(tree_size, TreeStats(Function::parse(Model().make_tree(tree_size)).root()).size); } - TreeStats stats1(Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in 1),2.0,3.0),4.0))").root()); + TreeStats stats1(Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))").root()); EXPECT_EQUAL(3u, stats1.num_params); EXPECT_EQUAL(4u, stats1.size); EXPECT_EQUAL(1u, stats1.num_less_checks); EXPECT_EQUAL(2u, stats1.num_in_checks); EXPECT_EQUAL(3u, stats1.max_set_size); - TreeStats stats2(Function::parse("if((d in 1),10.0,if((e<1),20.0,30.0))").root()); + TreeStats stats2(Function::parse("if((d in [1]),10.0,if((e<1),20.0,30.0))").root()); EXPECT_EQUAL(2u, stats2.num_params); EXPECT_EQUAL(3u, stats2.size); EXPECT_EQUAL(1u, stats2.num_less_checks); @@ -61,9 +61,9 @@ TEST("require that trees can be extracted from forest") { } TEST("require that forest stats can be calculated") { - Function function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in 1),2.0,3.0),4.0))+" - "if((d in 1),10.0,if((e<1),20.0,30.0))+" - "if((d in 1),10.0,if((e<1),20.0,30.0))"); + Function function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))+" + "if((d in [1]),10.0,if((e<1),20.0,30.0))+" + "if((d in [1]),10.0,if((e<1),20.0,30.0))"); std::vector<const Node *> trees = extract_trees(function.root()); ForestStats stats(trees); EXPECT_EQUAL(5u, stats.num_params); @@ -261,8 +261,8 @@ TEST("require that models with in checks are rejected by less only vm optimizer" } TEST("require that general VM tree optimizer works") { - Function function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in 1),2.0,3.0),4.0))+" - "if((d in 1),10.0,if((e<1),if((f<1),20.0,30.0),40.0))"); + Function function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))+" + "if((d in [1]),10.0,if((e<1),if((f<1),20.0,30.0),40.0))"); CompiledFunction compiled_function(function, PassParams::ARRAY, general_vm_chain); EXPECT_EQUAL(1u, compiled_function.get_forests().size()); auto f = compiled_function.get_function(); @@ -324,17 +324,17 @@ TEST("require that forests evaluate to approximately the same for all evaluation //----------------------------------------------------------------------------- TEST("require that GDBT expressions can be detected") { - Function function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in 1),2.0,3.0),4.0))+" - "if((d in 1),10.0,if((e<1),20.0,30.0))+" - "if((d in 1),10.0,if((e<1),20.0,30.0))"); + Function function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))+" + "if((d in [1]),10.0,if((e<1),20.0,30.0))+" + "if((d in [1]),10.0,if((e<1),20.0,30.0))"); EXPECT_TRUE(contains_gbdt(function.root(), 9)); EXPECT_TRUE(!contains_gbdt(function.root(), 10)); } TEST("require that wrapped GDBT expressions can be detected") { - Function function = Function::parse("10*(if((a<1),1.0,if((b in [1,2,3]),if((c in 1),2.0,3.0),4.0))+" - "if((d in 1),10.0,if((e<1),20.0,30.0))+" - "if((d in 1),10.0,if((e<1),20.0,30.0)))"); + Function function = Function::parse("10*(if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))+" + "if((d in [1]),10.0,if((e<1),20.0,30.0))+" + "if((d in [1]),10.0,if((e<1),20.0,30.0)))"); EXPECT_TRUE(contains_gbdt(function.root(), 9)); EXPECT_TRUE(!contains_gbdt(function.root(), 10)); } @@ -345,9 +345,9 @@ TEST("require that lazy parameters are not suggested for GBDT models") { } TEST("require that lazy parameters can be suggested for small GBDT models") { - Function function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in 1),2.0,3.0),4.0))+" - "if((d in 1),10.0,if((e<1),20.0,30.0))+" - "if((d in 1),10.0,if((e<1),20.0,30.0))"); + Function function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))+" + "if((d in [1]),10.0,if((e<1),20.0,30.0))+" + "if((d in [1]),10.0,if((e<1),20.0,30.0))"); EXPECT_TRUE(CompiledFunction::should_use_lazy_params(function)); } diff --git a/eval/src/tests/eval/node_types/node_types_test.cpp b/eval/src/tests/eval/node_types/node_types_test.cpp index e8009447793..0b6ea9e4d35 100644 --- a/eval/src/tests/eval/node_types/node_types_test.cpp +++ b/eval/src/tests/eval/node_types/node_types_test.cpp @@ -81,12 +81,6 @@ TEST("require that input parameters preserve their type") { TEST_DO(verify("tensor(x{},y[10],z[])", "tensor(x{},y[10],z[])")); } -TEST("require that arrays are double (size) unless they contain an error") { - TEST_DO(verify("[1,2,3]", "double")); - TEST_DO(verify("[any,tensor,double]", "double")); - TEST_DO(verify("[1,error,3]", "error")); -} - TEST("require that if resolves to the appropriate type") { TEST_DO(verify("if(error,1,2)", "error")); TEST_DO(verify("if(1,error,2)", "error")); @@ -108,17 +102,6 @@ TEST("require that if resolves to the appropriate type") { TEST_DO(verify("if(double,any,double)", "any")); } -TEST("require that set membership resolves to double unless error") { - TEST_DO(verify("1 in [1,2,3]", "double")); - TEST_DO(verify("1 in [tensor,tensor,tensor]", "double")); - TEST_DO(verify("1 in tensor", "double")); - TEST_DO(verify("tensor in 1", "double")); - TEST_DO(verify("tensor in [1,2,any]", "double")); - TEST_DO(verify("any in [1,tensor,any]", "double")); - TEST_DO(verify("error in [1,tensor,any]", "error")); - TEST_DO(verify("any in [tensor,error,any]", "error")); -} - TEST("require that reduce resolves correct type") { TEST_DO(verify("reduce(error,sum)", "error")); TEST_DO(verify("reduce(tensor,sum)", "double")); @@ -244,6 +227,10 @@ TEST("require that map resolves correct type") { TEST_DO(verify_op1("map(%s,f(x)(sin(x)))")); } +TEST("require that set membership resolves correct type") { + TEST_DO(verify_op1("%s in [1,2,3]")); +} + TEST("require that join resolves correct type") { TEST_DO(verify_op2("join(%s,%s,f(x,y)(x+y))")); } diff --git a/eval/src/vespa/eval/eval/basic_nodes.cpp b/eval/src/vespa/eval/eval/basic_nodes.cpp index 50db7370a66..a96d634e07a 100644 --- a/eval/src/vespa/eval/eval/basic_nodes.cpp +++ b/eval/src/vespa/eval/eval/basic_nodes.cpp @@ -60,7 +60,7 @@ Node::traverse(NodeTraverser &traverser) const void Number::accept(NodeVisitor &visitor) const { visitor.visit(*this); } void Symbol::accept(NodeVisitor &visitor) const { visitor.visit(*this); } void String::accept(NodeVisitor &visitor) const { visitor.visit(*this); } -void Array ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } +void In ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } void Neg ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } void Not ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } void If ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } @@ -124,7 +124,7 @@ If::If(Node_UP cond_in, Node_UP true_expr_in, Node_UP false_expr_in, double p_tr if (less) { _is_tree = (less->lhs().is_param() && less->rhs().is_const()); } else if (in) { - _is_tree = (in->lhs().is_param() && in->rhs().is_const()); + _is_tree = in->child().is_param(); } } } diff --git a/eval/src/vespa/eval/eval/basic_nodes.h b/eval/src/vespa/eval/eval/basic_nodes.h index 3f4c21be810..ebf65178b99 100644 --- a/eval/src/vespa/eval/eval/basic_nodes.h +++ b/eval/src/vespa/eval/eval/basic_nodes.h @@ -142,40 +142,39 @@ public: void accept(NodeVisitor &visitor) const override; }; -class Array : public Node { +class In : public Node { private: - std::vector<Node_UP> _nodes; - bool _is_const; + Node_UP _child; + std::vector<Node_UP> _entries; public: - Array() : _nodes(), _is_const(false) {} - bool is_const() const override { return _is_const; } - size_t size() const { return _nodes.size(); } - const Node &get(size_t i) const { return *_nodes[i]; } - size_t num_children() const override { return size(); } - const Node &get_child(size_t idx) const override { return get(idx); } - void detach_children(NodeHandler &handler) override { - for (size_t i = 0; i < _nodes.size(); ++i) { - handler.handle(std::move(_nodes[i])); - } - _nodes.clear(); + In(Node_UP child) : _child(std::move(child)), _entries() {} + void add_entry(Node_UP entry) { + assert(entry->is_const()); + _entries.push_back(std::move(entry)); } - void add(Node_UP node) { - if (_nodes.empty()) { - _is_const = node->is_const(); - } else { - _is_const = (_is_const && node->is_const()); - } - _nodes.push_back(std::move(node)); + size_t num_entries() const { return _entries.size(); } + const Node &get_entry(size_t idx) const { return *_entries[idx]; } + const Node &child() const { return *_child; } + size_t num_children() const override { return _child ? 1 : 0; } + const Node &get_child(size_t idx) const override { + (void) idx; + assert(idx == 0); + return child(); + } + void detach_children(NodeHandler &handler) override { + handler.handle(std::move(_child)); } vespalib::string dump(DumpContext &ctx) const override { vespalib::string str; - str += "["; + str += "("; + str += _child->dump(ctx); + str += " in ["; CommaTracker node_list; - for (const auto &node: _nodes) { + for (const auto &node: _entries) { node_list.maybe_comma(str); str += node->dump(ctx); } - str += "]"; + str += "])"; return str; } void accept(NodeVisitor &visitor) const override; diff --git a/eval/src/vespa/eval/eval/function.cpp b/eval/src/vespa/eval/eval/function.cpp index 0843baa700c..c4f91067260 100644 --- a/eval/src/vespa/eval/eval/function.cpp +++ b/eval/src/vespa/eval/eval/function.cpp @@ -281,12 +281,15 @@ public: size_t operator_mark() const { return _operator_mark; } void operator_mark(size_t mark) { _operator_mark = mark; } - void push_operator(Operator_UP node) { + void apply_until(const nodes::Operator &op) { while ((_operator_stack.size() > _operator_mark) && - (_operator_stack.back()->do_before(*node))) + (_operator_stack.back()->do_before(op))) { apply_operator(); } + } + void push_operator(Operator_UP node) { + apply_until(*node); _operator_stack.push_back(std::move(node)); } Operator_UP pop_operator() { @@ -299,6 +302,7 @@ public: //----------------------------------------------------------------------------- +void parse_value(ParseContext &ctx); void parse_expression(ParseContext &ctx); int unhex(char c) { @@ -642,8 +646,11 @@ void parse_symbol_or_call(ParseContext &ctx) { } } -void parse_array(ParseContext &ctx) { - std::unique_ptr<nodes::Array> array(new nodes::Array()); +void parse_in(ParseContext &ctx) +{ + ctx.apply_until(nodes::Less()); + auto in = std::make_unique<nodes::In>(ctx.pop_expression()); + ctx.skip_spaces(); ctx.eat('['); ctx.skip_spaces(); size_t size = 0; @@ -651,11 +658,19 @@ void parse_array(ParseContext &ctx) { if (++size > 1) { ctx.eat(','); } - parse_expression(ctx); - array->add(ctx.pop_expression()); + parse_value(ctx); + ctx.skip_spaces(); + auto entry = ctx.pop_expression(); + auto num = nodes::as<nodes::Number>(*entry); + auto str = nodes::as<nodes::String>(*entry); + if (num || str) { + in->add_entry(std::move(entry)); + } else { + ctx.fail("invalid entry for 'in' operator"); + } } ctx.eat(']'); - ctx.push_expression(std::move(array)); + ctx.push_expression(std::move(in)); } void parse_value(ParseContext &ctx) { @@ -663,7 +678,13 @@ void parse_value(ParseContext &ctx) { if (ctx.get() == '-') { ctx.next(); parse_value(ctx); - ctx.push_expression(Node_UP(new nodes::Neg(ctx.pop_expression()))); + auto entry = ctx.pop_expression(); + auto num = nodes::as<nodes::Number>(*entry); + if (num) { + ctx.push_expression(std::make_unique<nodes::Number>(-num->value())); + } else { + ctx.push_expression(std::make_unique<nodes::Neg>(std::move(entry))); + } } else if (ctx.get() == '!') { ctx.next(); parse_value(ctx); @@ -672,8 +693,6 @@ void parse_value(ParseContext &ctx) { ctx.next(); parse_expression(ctx); ctx.eat(')'); - } else if (ctx.get() == '[') { - parse_array(ctx); } else if (ctx.get() == '"') { parse_string(ctx); } else if (isdigit(ctx.get())) { @@ -683,7 +702,8 @@ void parse_value(ParseContext &ctx) { } } -void parse_operator(ParseContext &ctx) { +bool parse_operator(ParseContext &ctx) { + bool expect_value = true; ctx.skip_spaces(); vespalib::string &str = ctx.peek(ctx.scratch(), nodes::OperatorRepo::instance().max_size()); Operator_UP op = nodes::OperatorRepo::instance().create(str); @@ -691,24 +711,38 @@ void parse_operator(ParseContext &ctx) { ctx.push_operator(std::move(op)); ctx.skip(str.size()); } else { - ctx.fail(make_string("invalid operator: '%c'", ctx.get())); + vespalib::string ident = get_ident(ctx, true); + if (ident == "in") { + parse_in(ctx); + expect_value = false; + } else { + if (ident.empty()) { + ctx.fail(make_string("invalid operator: '%c'", ctx.get())); + } else { + ctx.fail(make_string("invalid operator: '%s'", ident.c_str())); + } + } } + return expect_value; } void parse_expression(ParseContext &ctx) { size_t old_mark = ctx.operator_mark(); ctx.operator_mark(ctx.num_operators()); + bool expect_value = true; for (;;) { - parse_value(ctx); + if (expect_value) { + parse_value(ctx); + } ctx.skip_spaces(); - if (ctx.eos() || ctx.get() == ')' || ctx.get() == ',' || ctx.get() == ']') { + if (ctx.eos() || ctx.get() == ')' || ctx.get() == ',') { while (ctx.num_operators() > ctx.operator_mark()) { ctx.apply_operator(); } ctx.operator_mark(old_mark); return; } - parse_operator(ctx); + expect_value = parse_operator(ctx); } } diff --git a/eval/src/vespa/eval/eval/gbdt.cpp b/eval/src/vespa/eval/eval/gbdt.cpp index cffd8bdb0c3..0edc42070ba 100644 --- a/eval/src/vespa/eval/eval/gbdt.cpp +++ b/eval/src/vespa/eval/eval/gbdt.cpp @@ -70,13 +70,11 @@ TreeStats::traverse(const nodes::Node &node, size_t depth, size_t &sum_path) { ++num_less_checks; } else { assert(in); - auto symbol = nodes::as<nodes::Symbol>(in->lhs()); + auto symbol = nodes::as<nodes::Symbol>(in->child()); assert(symbol); num_params = std::max(num_params, size_t(symbol->id() + 1)); ++num_in_checks; - auto array = nodes::as<nodes::Array>(in->rhs()); - size_t array_size = (array) ? array->size() : 1; - max_set_size = std::max(max_set_size, array_size); + max_set_size = std::max(max_set_size, in->num_entries()); } return 1.0 + (p_true * true_path) + ((1.0 - p_true) * false_path); } else { diff --git a/eval/src/vespa/eval/eval/interpreted_function.cpp b/eval/src/vespa/eval/eval/interpreted_function.cpp index caf71fcb68b..f99c4ace2dd 100644 --- a/eval/src/vespa/eval/eval/interpreted_function.cpp +++ b/eval/src/vespa/eval/eval/interpreted_function.cpp @@ -65,24 +65,6 @@ void op_skip_if_false(State &state, uint64_t param) { //----------------------------------------------------------------------------- -// compare lhs with a set member, short-circuit if found -void op_check_member(State &state, uint64_t param) { - if (state.peek(1).equal(state.peek(0))) { - state.replace(2, state.stash.create<DoubleValue>(1.0)); - state.program_offset += param; - } else { - state.stack.pop_back(); - } -} - -// set member not found, replace lhs with false -void op_not_member(State &state, uint64_t) { - state.stack.pop_back(); - state.stack.push_back(state.stash.create<DoubleValue>(0.0)); -} - -//----------------------------------------------------------------------------- - void op_double_map(State &state, uint64_t param) { state.replace(1, state.stash.create<DoubleValue>(to_map_fun(param)(state.peek(0).as_double()))); } @@ -252,8 +234,14 @@ struct ProgramBuilder : public NodeVisitor, public NodeTraverser { void visit(const String &node) override { make_const_op(node, stash.create<DoubleValue>(node.hash())); } - void visit(const Array &node) override { - make_const_op(node, stash.create<DoubleValue>(node.size())); + void visit(const In &node) override { + auto my_in = std::make_unique<In>(std::make_unique<Symbol>(0)); + for (size_t i = 0; i < node.num_entries(); ++i) { + my_in->add_entry(std::make_unique<Number>(node.get_entry(i).get_const_value())); + } + Function my_fun(std::move(my_in), {"x"}); + const auto &token = stash.create<CompileCache::Token::UP>(CompileCache::compile(my_fun, PassParams::SEPARATE)); + make_map_op(node, token.get()->get().get_function<1>()); } void visit(const Neg &node) override { make_map_op(node, operation::Neg::f); @@ -367,26 +355,6 @@ struct ProgramBuilder : public NodeVisitor, public NodeTraverser { void visit(const GreaterEqual &node) override { make_join_op(node, operation::GreaterEqual::f); } - void visit(const In &node) override { - std::vector<size_t> checks; - node.lhs().traverse(*this); - auto array = as<Array>(node.rhs()); - if (array) { - for (size_t i = 0; i < array->size(); ++i) { - array->get(i).traverse(*this); - checks.push_back(program.size()); - program.emplace_back(op_check_member); - } - } else { - node.rhs().traverse(*this); - checks.push_back(program.size()); - program.emplace_back(op_check_member); - } - for (size_t i = 0; i < checks.size(); ++i) { - program[checks[i]].update_param(program.size() - checks[i]); - } - program.emplace_back(op_not_member); - } void visit(const And &node) override { make_join_op(node, operation::And::f); } @@ -472,7 +440,7 @@ struct ProgramBuilder : public NodeVisitor, public NodeTraverser { //------------------------------------------------------------------------- bool open(const Node &node) override { - if (check_type<Array, If, In>(node)) { + if (check_type<If>(node)) { node.accept(*this); return false; } diff --git a/eval/src/vespa/eval/eval/key_gen.cpp b/eval/src/vespa/eval/eval/key_gen.cpp index a745023fcf4..e0494e1fe11 100644 --- a/eval/src/vespa/eval/eval/key_gen.cpp +++ b/eval/src/vespa/eval/eval/key_gen.cpp @@ -25,7 +25,12 @@ struct KeyGen : public NodeVisitor, public NodeTraverser { void visit(const Number &node) override { add_byte( 1); add_double(node.value()); } void visit(const Symbol &node) override { add_byte( 2); add_int(node.id()); } void visit(const String &node) override { add_byte( 3); add_hash(node.hash()); } - void visit(const Array &node) override { add_byte( 4); add_size(node.size()); } + void visit(const In &node) override { add_byte( 4); + add_size(node.num_entries()); + for (size_t i = 0; i < node.num_entries(); ++i) { + add_double(node.get_entry(i).get_const_value()); + } + } void visit(const Neg &) override { add_byte( 5); } void visit(const Not &) override { add_byte( 6); } void visit(const If &node) override { add_byte( 7); add_double(node.p_true()); } @@ -49,7 +54,6 @@ struct KeyGen : public NodeVisitor, public NodeTraverser { void visit(const LessEqual &) override { add_byte(30); } void visit(const Greater &) override { add_byte(31); } void visit(const GreaterEqual &) override { add_byte(32); } - void visit(const In &) override { add_byte(33); } void visit(const And &) override { add_byte(34); } void visit(const Or &) override { add_byte(35); } void visit(const Cos &) override { add_byte(36); } diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp index 303833e8e6c..9355cf7a4e4 100644 --- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp +++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp @@ -56,9 +56,9 @@ namespace { struct SetMemberHash : PluginState { vespalib::hash_set<double> members; - explicit SetMemberHash(const Array &array) : members(array.size() * 3) { - for (size_t i = 0; i < array.size(); ++i) { - members.insert(array.get(i).get_const_value()); + explicit SetMemberHash(const In &in) : members(in.num_entries() * 3) { + for (size_t i = 0; i < in.num_entries(); ++i) { + members.insert(in.get_entry(i).get_const_value()); } } static bool check_membership(const PluginState *state, double value) { @@ -252,7 +252,7 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser { inside_forest = true; forest_end = &node; } - if (check_type<Array, If, In>(node)) { + if (check_type<If>(node)) { node.accept(*this); return false; } @@ -355,9 +355,27 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser { void visit(const String &item) override { push_double(item.hash()); } - void visit(const Array &item) override { - // NB: visit not open - push_double(item.size()); + void visit(const In &item) override { + llvm::Value *lhs = pop_double(); + if (item.num_entries() > 8) { + // build call to hash lookup + plugin_state.emplace_back(new SetMemberHash(item)); + void *call_ptr = (void *) SetMemberHash::check_membership; + PluginState *state = plugin_state.back().get(); + llvm::PointerType *funptr_t = make_check_membership_funptr_t(); + llvm::Value *call_fun = builder.CreateIntToPtr(builder.getInt64((uint64_t)call_ptr), funptr_t, "inject_call_addr"); + llvm::Value *ctx = builder.CreateIntToPtr(builder.getInt64((uint64_t)state), builder.getVoidTy()->getPointerTo(), "inject_ctx"); + push(builder.CreateCall(call_fun, {ctx, lhs}, "call_check_membership")); + } else { + // build explicit code to check all set members + llvm::Value *found = builder.getFalse(); + for (size_t i = 0; i < item.num_entries(); ++i) { + llvm::Value *elem = llvm::ConstantFP::get(builder.getDoubleTy(), item.get_entry(i).get_const_value()); + llvm::Value *elem_eq = builder.CreateFCmpOEQ(lhs, elem, "elem_eq"); + found = builder.CreateOr(found, elem_eq, "found"); + } + push(found); + } } void visit(const Neg &) override { llvm::Value *child = pop_double(); @@ -480,38 +498,6 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser { llvm::Value *a = pop_double(); push(builder.CreateFCmpOGE(a, b, "cmp_ge_res")); } - void visit(const In &item) override { - // NB: visit not open - item.lhs().traverse(*this); // NB: recursion - llvm::Value *lhs = pop_double(); - auto array = as<Array>(item.rhs()); - if (array) { - if (array->is_const() && array->size() > 8) { - // build call to hash lookup - plugin_state.emplace_back(new SetMemberHash(*array)); - void *call_ptr = (void *) SetMemberHash::check_membership; - PluginState *state = plugin_state.back().get(); - llvm::PointerType *funptr_t = make_check_membership_funptr_t(); - llvm::Value *call_fun = builder.CreateIntToPtr(builder.getInt64((uint64_t)call_ptr), funptr_t, "inject_call_addr"); - llvm::Value *ctx = builder.CreateIntToPtr(builder.getInt64((uint64_t)state), builder.getVoidTy()->getPointerTo(), "inject_ctx"); - push(builder.CreateCall(call_fun, {ctx, lhs}, "call_check_membership")); - } else { - // build explicit code to check all set members - llvm::Value *found = builder.getFalse(); - for (size_t i = 0; i < array->size(); ++i) { - array->get(i).traverse(*this); // NB: recursion - llvm::Value *elem = pop_double(); - llvm::Value *elem_eq = builder.CreateFCmpOEQ(lhs, elem, "elem_eq"); - found = builder.CreateOr(found, elem_eq, "found"); - } - push(found); - } - } else { - item.rhs().traverse(*this); // NB: recursion - llvm::Value *rhs = pop_double(); - push(builder.CreateFCmpOEQ(lhs, rhs, "rhs_eq")); - } - } void visit(const And &) override { llvm::Value *b = pop_bool(); llvm::Value *a = pop_bool(); diff --git a/eval/src/vespa/eval/eval/node_types.cpp b/eval/src/vespa/eval/eval/node_types.cpp index d995c5281c4..f86c3e1a84a 100644 --- a/eval/src/vespa/eval/eval/node_types.cpp +++ b/eval/src/vespa/eval/eval/node_types.cpp @@ -90,9 +90,7 @@ struct TypeResolver : public NodeVisitor, public NodeTraverser { void visit(const String &node) override { bind_type(ValueType::double_type(), node); } - void visit(const Array &node) override { - bind_type(ValueType::double_type(), node); - } + void visit(const In &node) override { resolve_op1(node); } void visit(const Neg &node) override { resolve_op1(node); } void visit(const Not &node) override { resolve_op1(node); } void visit(const If &node) override { @@ -139,9 +137,6 @@ struct TypeResolver : public NodeVisitor, public NodeTraverser { void visit(const LessEqual &node) override { resolve_op2(node); } void visit(const Greater &node) override { resolve_op2(node); } void visit(const GreaterEqual &node) override { resolve_op2(node); } - void visit(const In &node) override { - bind_type(ValueType::double_type(), node); - } void visit(const And &node) override { resolve_op2(node); } void visit(const Or &node) override { resolve_op2(node); } void visit(const Cos &node) override { resolve_op1(node); } diff --git a/eval/src/vespa/eval/eval/node_visitor.h b/eval/src/vespa/eval/eval/node_visitor.h index 91d65ceb7ec..c5a6fd51373 100644 --- a/eval/src/vespa/eval/eval/node_visitor.h +++ b/eval/src/vespa/eval/eval/node_visitor.h @@ -22,7 +22,7 @@ struct NodeVisitor { virtual void visit(const nodes::Number &) = 0; virtual void visit(const nodes::Symbol &) = 0; virtual void visit(const nodes::String &) = 0; - virtual void visit(const nodes::Array &) = 0; + virtual void visit(const nodes::In &) = 0; virtual void visit(const nodes::Neg &) = 0; virtual void visit(const nodes::Not &) = 0; virtual void visit(const nodes::If &) = 0; @@ -50,7 +50,6 @@ struct NodeVisitor { virtual void visit(const nodes::LessEqual &) = 0; virtual void visit(const nodes::Greater &) = 0; virtual void visit(const nodes::GreaterEqual &) = 0; - virtual void visit(const nodes::In &) = 0; virtual void visit(const nodes::And &) = 0; virtual void visit(const nodes::Or &) = 0; @@ -92,7 +91,7 @@ struct EmptyNodeVisitor : NodeVisitor { void visit(const nodes::Number &) override {} void visit(const nodes::Symbol &) override {} void visit(const nodes::String &) override {} - void visit(const nodes::Array &) override {} + void visit(const nodes::In &) override {} void visit(const nodes::Neg &) override {} void visit(const nodes::Not &) override {} void visit(const nodes::If &) override {} @@ -116,7 +115,6 @@ struct EmptyNodeVisitor : NodeVisitor { void visit(const nodes::LessEqual &) override {} void visit(const nodes::Greater &) override {} void visit(const nodes::GreaterEqual &) override {} - void visit(const nodes::In &) override {} void visit(const nodes::And &) override {} void visit(const nodes::Or &) override {} void visit(const nodes::Cos &) override {} diff --git a/eval/src/vespa/eval/eval/operator_nodes.cpp b/eval/src/vespa/eval/eval/operator_nodes.cpp index 11817630da4..4c66268dfa2 100644 --- a/eval/src/vespa/eval/eval/operator_nodes.cpp +++ b/eval/src/vespa/eval/eval/operator_nodes.cpp @@ -37,23 +37,10 @@ OperatorRepo::OperatorRepo() : _map(), _max_size(0) { add(nodes::LessEqual()); add(nodes::Greater()); add(nodes::GreaterEqual()); - add(nodes::In()); add(nodes::And()); add(nodes::Or()); } -vespalib::string -In::dump(DumpContext &ctx) const -{ - vespalib::string str; - str += "("; - str += lhs().dump(ctx); - str += " in "; - str += rhs().dump(ctx); - str += ")"; - return str; -} - } // namespace vespalib::eval::nodes } // namespace vespalib::eval } // namespace vespalib diff --git a/eval/src/vespa/eval/eval/operator_nodes.h b/eval/src/vespa/eval/eval/operator_nodes.h index e4dda484d68..eafd817d42c 100644 --- a/eval/src/vespa/eval/eval/operator_nodes.h +++ b/eval/src/vespa/eval/eval/operator_nodes.h @@ -166,9 +166,6 @@ struct Less : OperatorHelper<Less> { Less() : Helper("<" struct LessEqual : OperatorHelper<LessEqual> { LessEqual() : Helper("<=", 10, LEFT) {}}; struct Greater : OperatorHelper<Greater> { Greater() : Helper(">", 10, LEFT) {}}; struct GreaterEqual : OperatorHelper<GreaterEqual> { GreaterEqual() : Helper(">=", 10, LEFT) {}}; -struct In : OperatorHelper<In> { In() : Helper("in", 10, LEFT) {} - virtual vespalib::string dump(DumpContext &ctx) const override; -}; struct And : OperatorHelper<And> { And() : Helper("&&", 2, LEFT) {}}; struct Or : OperatorHelper<Or> { Or() : Helper("||", 1, LEFT) {}}; diff --git a/eval/src/vespa/eval/eval/test/eval_spec.cpp b/eval/src/vespa/eval/eval/test/eval_spec.cpp index c7c0d754976..d214486cf21 100644 --- a/eval/src/vespa/eval/eval/test/eval_spec.cpp +++ b/eval/src/vespa/eval/eval/test/eval_spec.cpp @@ -104,12 +104,6 @@ EvalSpec::add_terminal_cases() { add_expression({}, "10").add_case({}, 10.0); add_expression({}, "100").add_case({}, 100.0); add_rule({"a", -5.0, 5.0}, "a", [](double a){ return a; }); - add_expression({}, "[]").add_case({}, 0.0); - add_expression({}, "[1]").add_case({}, 1.0); - add_expression({}, "[1,2]").add_case({}, 2.0); - add_expression({}, "[1,2,3]").add_case({}, 3.0); - add_expression({}, "[3,2,1]").add_case({}, 3.0); - add_expression({}, "[1,1,1,1,1]").add_case({}, 5.0); add_expression({}, "\"\"").add_case({}, vespalib::hash_code("")); add_expression({}, "\"foo\"").add_case({}, vespalib::hash_code("foo")); add_expression({}, "\"foo bar baz\"").add_case({}, vespalib::hash_code("foo bar baz")); @@ -277,52 +271,27 @@ EvalSpec::add_set_membership_cases() { add_expression({"a"}, "(a in [])") .add_case({0.0}, 0.0) - .add_case({1.0}, 0.0) - .add_case({2.0}, 0.0); - - add_expression({"a"}, "(a in [[]])") - .add_case({0.0}, 1.0) - .add_case({1.0}, 0.0) - .add_case({2.0}, 0.0); - - add_expression({"a"}, "(a in [[[]]])") - .add_case({0.0}, 0.0) - .add_case({1.0}, 1.0) - .add_case({2.0}, 0.0); - - add_expression({"a", "b"}, "(a in b)") - .add_case({my_nan, 2.0}, 0.0) - .add_case({2.0, my_nan}, 0.0) - .add_case({my_nan, my_nan}, 0.0) - .add_case({1.0, 2.0}, 0.0) - .add_case({2.0 - 1e-10, 2.0}, 0.0) - .add_case({2.0, 2.0}, 1.0) - .add_case({2.0 + 1e-10, 2.0}, 0.0) - .add_case({3.0, 2.0}, 0.0); - - add_expression({"a", "b"}, "(a in [b])") - .add_case({my_nan, 2.0}, 0.0) - .add_case({2.0, my_nan}, 0.0) - .add_case({my_nan, my_nan}, 0.0) - .add_case({1.0, 2.0}, 0.0) - .add_case({2.0 - 1e-10, 2.0}, 0.0) - .add_case({2.0, 2.0}, 1.0) - .add_case({2.0 + 1e-10, 2.0}, 0.0) - .add_case({3.0, 2.0}, 0.0); - - add_expression({"a", "b"}, "(a in [[b]])") - .add_case({1.0, 2.0}, 1.0) - .add_case({2.0, 2.0}, 0.0); - - add_expression({"a", "b", "c", "d"}, "(a in [b,c,d])") - .add_case({0.0, 10.0, 20.0, 30.0}, 0.0) - .add_case({3.0, 10.0, 20.0, 30.0}, 0.0) - .add_case({10.0, 10.0, 20.0, 30.0}, 1.0) - .add_case({20.0, 10.0, 20.0, 30.0}, 1.0) - .add_case({30.0, 10.0, 20.0, 30.0}, 1.0) - .add_case({10.0, 30.0, 20.0, 10.0}, 1.0) - .add_case({20.0, 30.0, 20.0, 10.0}, 1.0) - .add_case({30.0, 30.0, 20.0, 10.0}, 1.0); + .add_case({1.0}, 0.0); + + add_expression({"a"}, "(a in [2.0])") + .add_case({my_nan}, 0.0) + .add_case({1.0}, 0.0) + .add_case({2.0 - 1e-10}, 0.0) + .add_case({2.0}, 1.0) + .add_case({2.0 + 1e-10}, 0.0) + .add_case({3.0}, 0.0); + + add_expression({"a"}, "(a in [10,20,30])") + .add_case({0.0}, 0.0) + .add_case({3.0}, 0.0) + .add_case({10.0}, 1.0) + .add_case({20.0}, 1.0) + .add_case({30.0}, 1.0); + + add_expression({"a"}, "(a in [30,20,10])") + .add_case({10.0}, 1.0) + .add_case({20.0}, 1.0) + .add_case({30.0}, 1.0); } void diff --git a/eval/src/vespa/eval/eval/vm_forest.cpp b/eval/src/vespa/eval/eval/vm_forest.cpp index 4a73394e354..c456c660af7 100644 --- a/eval/src/vespa/eval/eval/vm_forest.cpp +++ b/eval/src/vespa/eval/eval/vm_forest.cpp @@ -128,20 +128,13 @@ void encode_in(const nodes::In &in, std::vector<uint32_t> &model_out) { size_t meta_idx = model_out.size(); - auto symbol = nodes::as<nodes::Symbol>(in.lhs()); + auto symbol = nodes::as<nodes::Symbol>(in.child()); assert(symbol); model_out.push_back(uint32_t(symbol->id()) << 12); - assert(in.rhs().is_const()); - auto array = nodes::as<nodes::Array>(in.rhs()); size_t set_size_idx = model_out.size(); - if (array) { - model_out.push_back(array->size()); - for (size_t i = 0; i < array->size(); ++i) { - encode_const(array->get(i).get_const_value(), model_out); - } - } else { - model_out.push_back(1); - encode_const(in.rhs().get_const_value(), model_out); + model_out.push_back(in.num_entries()); + for (size_t i = 0; i < in.num_entries(); ++i) { + encode_const(in.get_entry(i).get_const_value(), model_out); } size_t left_idx = model_out.size(); uint32_t left_type = encode_node(left_child, model_out); |