summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2017-11-03 11:44:25 +0000
committerHåvard Pettersen <havardpe@oath.com>2017-11-03 12:35:31 +0000
commit228a4f089d2431bf2012982bce9e093f2df2dead (patch)
tree01119a93d453585edf6c88017c24abff534b0b21 /eval
parent5a69cb546ad5661001c89eecbf13c8b41b57019c (diff)
handle 'in' operator as custom (tensor) map operation
free arrays no longer allowed restrict set members to be numbers or strings auto-unbox negative numbers in AST
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/eval/compiled_function/compiled_function_test.cpp15
-rw-r--r--eval/src/tests/eval/function/function_test.cpp101
-rw-r--r--eval/src/tests/eval/gbdt/gbdt_test.cpp32
-rw-r--r--eval/src/tests/eval/node_types/node_types_test.cpp21
-rw-r--r--eval/src/vespa/eval/eval/basic_nodes.cpp4
-rw-r--r--eval/src/vespa/eval/eval/basic_nodes.h47
-rw-r--r--eval/src/vespa/eval/eval/function.cpp64
-rw-r--r--eval/src/vespa/eval/eval/gbdt.cpp6
-rw-r--r--eval/src/vespa/eval/eval/interpreted_function.cpp50
-rw-r--r--eval/src/vespa/eval/eval/key_gen.cpp8
-rw-r--r--eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp64
-rw-r--r--eval/src/vespa/eval/eval/node_types.cpp7
-rw-r--r--eval/src/vespa/eval/eval/node_visitor.h6
-rw-r--r--eval/src/vespa/eval/eval/operator_nodes.cpp13
-rw-r--r--eval/src/vespa/eval/eval/operator_nodes.h3
-rw-r--r--eval/src/vespa/eval/eval/test/eval_spec.cpp73
-rw-r--r--eval/src/vespa/eval/eval/vm_forest.cpp15
17 files changed, 224 insertions, 305 deletions
diff --git a/eval/src/tests/eval/compiled_function/compiled_function_test.cpp b/eval/src/tests/eval/compiled_function/compiled_function_test.cpp
index b887c6e45f9..0e9806d5381 100644
--- a/eval/src/tests/eval/compiled_function/compiled_function_test.cpp
+++ b/eval/src/tests/eval/compiled_function/compiled_function_test.cpp
@@ -128,16 +128,13 @@ TEST_FF("require that compiled evaluation passes all conformance tests", MyEvalT
//-----------------------------------------------------------------------------
TEST("require that large (plugin) set membership checks work") {
- nodes::Array my_set;
+ auto my_in = std::make_unique<nodes::In>(std::make_unique<nodes::Symbol>(0));
for(size_t i = 1; i <= 100; ++i) {
- my_set.add(nodes::Node_UP(new nodes::Number(i)));
+ my_in->add_entry(std::make_unique<nodes::Number>(i));
}
- nodes::DumpContext dump_ctx({});
- vespalib::string expr = vespalib::make_string("if(a in %s,1,0)",
- my_set.dump(dump_ctx).c_str());
- // fprintf(stderr, "expression: %s\n", expr.c_str());
- CompiledFunction cf(Function::parse(expr), PassParams::SEPARATE);
- CompiledFunction arr_cf(Function::parse(expr), PassParams::ARRAY);
+ Function my_fun(std::move(my_in), {"a"});
+ CompiledFunction cf(my_fun, PassParams::SEPARATE);
+ CompiledFunction arr_cf(my_fun, PassParams::ARRAY);
auto fun = cf.get_function<1>();
auto arr_fun = arr_cf.get_function();
for (double value = 0.5; value <= 100.5; value += 0.5) {
@@ -146,7 +143,7 @@ TEST("require that large (plugin) set membership checks work") {
EXPECT_EQUAL(1.0, arr_fun(&value));
} else {
EXPECT_EQUAL(0.0, fun(value));
- EXPECT_EQUAL(0.0, arr_fun(&value));
+ EXPECT_EQUAL(0.0, arr_fun(&value));
}
}
}
diff --git a/eval/src/tests/eval/function/function_test.cpp b/eval/src/tests/eval/function/function_test.cpp
index 75a6df41b50..6c3839b6cc9 100644
--- a/eval/src/tests/eval/function/function_test.cpp
+++ b/eval/src/tests/eval/function/function_test.cpp
@@ -81,6 +81,11 @@ bool verify_string(const vespalib::string &str, const vespalib::string &expr) {
return ok;
}
+void verify_error(const vespalib::string &expr, const vespalib::string &expected_error) {
+ Function function = Function::parse(params, expr);
+ EXPECT_TRUE(function.has_error());
+ EXPECT_EQUAL(expected_error, function.get_error());
+}
TEST("require that scientific numbers can be parsed") {
EXPECT_EQUAL(1.0, as_number(Function::parse(params, "1")));
@@ -163,18 +168,16 @@ TEST("require that strings are parsed and dumped correctly") {
}
}
-TEST("require that arrays can be parsed") {
- EXPECT_EQUAL("[]", Function::parse(params, "[]").dump());
- EXPECT_EQUAL("[1,2,3]", Function::parse(params, "[1,2,3]").dump());
- EXPECT_EQUAL("[1,2,3]", Function::parse(params, "[ 1 , 2 , 3 ]").dump());
- EXPECT_EQUAL("[[x],[x,y],[1,2,[z,w]]]", Function::parse(params, "[[x],[x,y],[1,2,[z,w]]]").dump());
- EXPECT_EQUAL("[(x+1),(y-[3,7]),z,[]]", Function::parse(params, "[x+1,y-[3,7],z,[]]").dump());
+TEST("require that free arrays cannot be parsed") {
+ verify_error("[1,2,3]", "[]...[missing value]...[[1,2,3]]");
}
TEST("require that negative values can be parsed") {
- EXPECT_EQUAL("(-1)", Function::parse(params, "-1").dump());
- EXPECT_EQUAL("(-2.5)", Function::parse(params, "-2.5").dump());
- EXPECT_EQUAL("(-100)", Function::parse(params, "-100").dump());
+ EXPECT_EQUAL("-1", Function::parse(params, "-1").dump());
+ EXPECT_EQUAL("1", Function::parse(params, "--1").dump());
+ EXPECT_EQUAL("-1", Function::parse(params, " ( - ( - ( - ( (1) ) ) ) )").dump());
+ EXPECT_EQUAL("-2.5", Function::parse(params, "-2.5").dump());
+ EXPECT_EQUAL("-100", Function::parse(params, "-100").dump());
}
TEST("require that negative symbols can be parsed") {
@@ -206,7 +209,7 @@ TEST("require that operators have appropriate binding order") {
verify_operator_binding_order({ { Operator::Order::RIGHT, { "^" } },
{ Operator::Order::LEFT, { "*", "/", "%" } },
{ Operator::Order::LEFT, { "+", "-" } },
- { Operator::Order::LEFT, { "==", "!=", "~=", "<", "<=", ">", ">=", "in" } },
+ { Operator::Order::LEFT, { "==", "!=", "~=", "<", "<=", ">", ">=" } },
{ Operator::Order::LEFT, { "&&" } },
{ Operator::Order::LEFT, { "||" } } });
}
@@ -248,10 +251,31 @@ TEST("require that operators can not bind out of parenthesis") {
}
TEST("require that set membership constructs can be parsed") {
- EXPECT_EQUAL("(x in [y,z,w])", Function::parse(params, "x in [y,z,w]").dump());
- EXPECT_EQUAL("(x in [y,z,w])", Function::parse(params, "x in[y,z,w]").dump());
- EXPECT_EQUAL("(x in [y,z,w])", Function::parse(params, "(x)in[y,z,w]").dump());
- EXPECT_EQUAL("((x+1) in [y,z,(w-1)])", Function::parse(params, "(x+1)in[y,z,(w-1)]").dump());
+ EXPECT_EQUAL("(x in [1,2,3])", Function::parse(params, "x in [1,2,3]").dump());
+ EXPECT_EQUAL("(x in [1,2,3])", Function::parse(params, "x in [ 1 , 2 , 3 ] ").dump());
+ EXPECT_EQUAL("(x in [-1,-2,-3])", Function::parse(params, "x in [-1,-2,-3]").dump());
+ EXPECT_EQUAL("(x in [-1,-2,-3])", Function::parse(params, "x in [ - 1 , - 2 , - 3 ]").dump());
+ EXPECT_EQUAL("(x in [1,2,3])", Function::parse(params, "x in[1,2,3]").dump());
+ EXPECT_EQUAL("(x in [1,2,3])", Function::parse(params, "(x)in[1,2,3]").dump());
+ EXPECT_EQUAL("(x in [\"a\",2,\"c\"])", Function::parse(params, "x in [\"a\",2,\"c\"]").dump());
+}
+
+TEST("require that set membership entries must be array of strings/numbers") {
+ verify_error("x in 1", "[x in ]...[expected '[', but got '1']...[1]");
+ verify_error("x in ([1])", "[x in ]...[expected '[', but got '(']...[([1])]");
+ verify_error("x in [y]", "[x in [y]...[invalid entry for 'in' operator]...[]]");
+ verify_error("x in [!1]", "[x in [!1]...[invalid entry for 'in' operator]...[]]");
+ verify_error("x in [1+2]", "[x in [1]...[expected ',', but got '+']...[+2]]");
+ verify_error("x in [-\"foo\"]", "[x in [-\"foo\"]...[invalid entry for 'in' operator]...[]]");
+}
+
+TEST("require that set membership binds to the next value") {
+ EXPECT_EQUAL("((x in [1,2,3])^2)", Function::parse(params, "x in [1,2,3]^2").dump());
+}
+
+TEST("require that set membership binds to the left with appropriate precedence") {
+ EXPECT_EQUAL("((x<y) in [1,2,3])", Function::parse(params, "x < y in [1,2,3]").dump());
+ EXPECT_EQUAL("(x&&(y in [1,2,3]))", Function::parse(params, "x && y in [1,2,3]").dump());
}
TEST("require that function calls can be parsed") {
@@ -309,22 +333,12 @@ TEST("require that leaf nodes have no children") {
EXPECT_EQUAL(0u, Function::parse("\"abc\"").root().num_children());
}
-TEST("require that Array children can be accessed") {
- Function f = Function::parse("[1,2,3]");
- const Node &root = f.root();
- EXPECT_TRUE(!root.is_leaf());
- ASSERT_EQUAL(3u, root.num_children());
- EXPECT_EQUAL(1.0, root.get_child(0).get_const_value());
- EXPECT_EQUAL(2.0, root.get_child(1).get_const_value());
- EXPECT_EQUAL(3.0, root.get_child(2).get_const_value());
-}
-
TEST("require that Neg child can be accessed") {
- Function f = Function::parse("-1");
+ Function f = Function::parse("-x");
const Node &root = f.root();
EXPECT_TRUE(!root.is_leaf());
ASSERT_EQUAL(1u, root.num_children());
- EXPECT_EQUAL(1.0, root.get_child(0).get_const_value());
+ EXPECT_TRUE(root.get_child(0).is_param());
}
TEST("require that Not child can be accessed") {
@@ -386,7 +400,7 @@ TEST("require that children can be detached") {
EXPECT_EQUAL(1u, detach_from_root("-a"));
EXPECT_EQUAL(1u, detach_from_root("!a"));
EXPECT_EQUAL(3u, detach_from_root("if(1,2,3)"));
- EXPECT_EQUAL(5u, detach_from_root("[1,2,3,4,5]"));
+ EXPECT_EQUAL(1u, detach_from_root("a in [1,2,3,4,5]"));
EXPECT_EQUAL(2u, detach_from_root("a+b"));
EXPECT_EQUAL(1u, detach_from_root("isNan(a)"));
EXPECT_EQUAL(2u, detach_from_root("max(a,b)"));
@@ -456,7 +470,7 @@ TEST("require that traversal works as expected") {
EXPECT_TRUE(verify_expression_traversal("1"));
EXPECT_TRUE(verify_expression_traversal("1+2"));
EXPECT_TRUE(verify_expression_traversal("1+2*3-4/5"));
- EXPECT_TRUE(verify_expression_traversal("if(x,1+2*3,[a,b,c]/5)"));
+ EXPECT_TRUE(verify_expression_traversal("if(x,1+2*3,if(a,b,c)/5)"));
}
//-----------------------------------------------------------------------------
@@ -492,14 +506,6 @@ TEST("require that string is const") {
EXPECT_TRUE(Function::parse("\"x\"").root().is_const());
}
-TEST("require that array is const if all elements are const") {
- EXPECT_TRUE(Function::parse("[1,2,3]").root().is_const());
- EXPECT_TRUE(!Function::parse("[x,2,3]").root().is_const());
- EXPECT_TRUE(!Function::parse("[1,y,3]").root().is_const());
- EXPECT_TRUE(!Function::parse("[1,2,z]").root().is_const());
- EXPECT_TRUE(!Function::parse("[x,y,z]").root().is_const());
-}
-
TEST("require that neg is const if sub-expression is const") {
EXPECT_TRUE(Function::parse("-123").root().is_const());
EXPECT_TRUE(!Function::parse("-x").root().is_const());
@@ -517,11 +523,11 @@ TEST("require that operators are cost if both children are const") {
EXPECT_TRUE(Function::parse("1+2").root().is_const());
}
-TEST("require that set membership is const only if array elements are const") {
+TEST("require that set membership is never tagged as const (NB: avoids jit recursion)") {
EXPECT_TRUE(!Function::parse("x in [x,y,z]").root().is_const());
EXPECT_TRUE(!Function::parse("1 in [x,y,z]").root().is_const());
EXPECT_TRUE(!Function::parse("1 in [1,y,z]").root().is_const());
- EXPECT_TRUE(Function::parse("1 in [1,2,3]").root().is_const());
+ EXPECT_TRUE(!Function::parse("1 in [1,2,3]").root().is_const());
}
TEST("require that calls are cost if all parameters are const") {
@@ -554,10 +560,8 @@ TEST("require that feature in set of constants is tree if children are trees or
EXPECT_TRUE(Function::parse("if (foo in [1, 2], if(bar < 3, 4, 5), 6)").root().is_tree());
EXPECT_TRUE(Function::parse("if (foo in [1, 2], if(bar < 3, 4, 5), if(baz < 6, 7, 8))").root().is_tree());
EXPECT_TRUE(Function::parse("if (foo in [1, 2], 3, if(baz < 4, 5, 6))").root().is_tree());
- EXPECT_TRUE(Function::parse("if (foo in [min(1,2), max(1,2)], 3, 4)").root().is_tree());
+ EXPECT_TRUE(Function::parse("if (foo in [1, 2], min(1,3), max(1,4))").root().is_tree());
EXPECT_TRUE(!Function::parse("if (1 in [1, 2], 3, 4)").root().is_tree());
- EXPECT_TRUE(!Function::parse("if (1 in [foo, 2], 3, 4)").root().is_tree());
- EXPECT_TRUE(!Function::parse("if (foo in [bar, 2], 3, 4)").root().is_tree());
}
TEST("require that sums of trees and forests are forests") {
@@ -671,14 +675,17 @@ TEST("require that unknown function that is valid parameter works as expected wi
EXPECT_EQUAL("[z(x)]...[unknown symbol: 'z(x)']...[+y]", Function::parse(params, "z(x)+y", MySymbolExtractor({'(', ')'})).dump());
}
-//-----------------------------------------------------------------------------
-
-void verify_error(const vespalib::string &expr, const vespalib::string &expected_error) {
- Function function = Function::parse(params, expr);
- EXPECT_TRUE(function.has_error());
- EXPECT_EQUAL(expected_error, function.get_error());
+TEST("require that custom symbol extractor is not invoked for known function call") {
+ MySymbolExtractor extractor;
+ EXPECT_EQUAL(extractor.invoke_count, 0u);
+ EXPECT_EQUAL("[bogus]...[unknown symbol: 'bogus']...[(1,2)]", Function::parse(params, "bogus(1,2)", extractor).dump());
+ EXPECT_EQUAL(extractor.invoke_count, 1u);
+ EXPECT_EQUAL("max(1,2)", Function::parse(params, "max(1,2)", extractor).dump());
+ EXPECT_EQUAL(extractor.invoke_count, 1u);
}
+//-----------------------------------------------------------------------------
+
TEST("require that valid function does not report parse error") {
Function function = Function::parse(params, "x + y");
EXPECT_TRUE(!function.has_error());
diff --git a/eval/src/tests/eval/gbdt/gbdt_test.cpp b/eval/src/tests/eval/gbdt/gbdt_test.cpp
index f4f5970cabe..20969a1e3b4 100644
--- a/eval/src/tests/eval/gbdt/gbdt_test.cpp
+++ b/eval/src/tests/eval/gbdt/gbdt_test.cpp
@@ -31,14 +31,14 @@ TEST("require that tree stats can be calculated") {
EXPECT_EQUAL(tree_size, TreeStats(Function::parse(Model().make_tree(tree_size)).root()).size);
}
- TreeStats stats1(Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in 1),2.0,3.0),4.0))").root());
+ TreeStats stats1(Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))").root());
EXPECT_EQUAL(3u, stats1.num_params);
EXPECT_EQUAL(4u, stats1.size);
EXPECT_EQUAL(1u, stats1.num_less_checks);
EXPECT_EQUAL(2u, stats1.num_in_checks);
EXPECT_EQUAL(3u, stats1.max_set_size);
- TreeStats stats2(Function::parse("if((d in 1),10.0,if((e<1),20.0,30.0))").root());
+ TreeStats stats2(Function::parse("if((d in [1]),10.0,if((e<1),20.0,30.0))").root());
EXPECT_EQUAL(2u, stats2.num_params);
EXPECT_EQUAL(3u, stats2.size);
EXPECT_EQUAL(1u, stats2.num_less_checks);
@@ -61,9 +61,9 @@ TEST("require that trees can be extracted from forest") {
}
TEST("require that forest stats can be calculated") {
- Function function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in 1),2.0,3.0),4.0))+"
- "if((d in 1),10.0,if((e<1),20.0,30.0))+"
- "if((d in 1),10.0,if((e<1),20.0,30.0))");
+ Function function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))+"
+ "if((d in [1]),10.0,if((e<1),20.0,30.0))+"
+ "if((d in [1]),10.0,if((e<1),20.0,30.0))");
std::vector<const Node *> trees = extract_trees(function.root());
ForestStats stats(trees);
EXPECT_EQUAL(5u, stats.num_params);
@@ -261,8 +261,8 @@ TEST("require that models with in checks are rejected by less only vm optimizer"
}
TEST("require that general VM tree optimizer works") {
- Function function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in 1),2.0,3.0),4.0))+"
- "if((d in 1),10.0,if((e<1),if((f<1),20.0,30.0),40.0))");
+ Function function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))+"
+ "if((d in [1]),10.0,if((e<1),if((f<1),20.0,30.0),40.0))");
CompiledFunction compiled_function(function, PassParams::ARRAY, general_vm_chain);
EXPECT_EQUAL(1u, compiled_function.get_forests().size());
auto f = compiled_function.get_function();
@@ -324,17 +324,17 @@ TEST("require that forests evaluate to approximately the same for all evaluation
//-----------------------------------------------------------------------------
TEST("require that GDBT expressions can be detected") {
- Function function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in 1),2.0,3.0),4.0))+"
- "if((d in 1),10.0,if((e<1),20.0,30.0))+"
- "if((d in 1),10.0,if((e<1),20.0,30.0))");
+ Function function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))+"
+ "if((d in [1]),10.0,if((e<1),20.0,30.0))+"
+ "if((d in [1]),10.0,if((e<1),20.0,30.0))");
EXPECT_TRUE(contains_gbdt(function.root(), 9));
EXPECT_TRUE(!contains_gbdt(function.root(), 10));
}
TEST("require that wrapped GDBT expressions can be detected") {
- Function function = Function::parse("10*(if((a<1),1.0,if((b in [1,2,3]),if((c in 1),2.0,3.0),4.0))+"
- "if((d in 1),10.0,if((e<1),20.0,30.0))+"
- "if((d in 1),10.0,if((e<1),20.0,30.0)))");
+ Function function = Function::parse("10*(if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))+"
+ "if((d in [1]),10.0,if((e<1),20.0,30.0))+"
+ "if((d in [1]),10.0,if((e<1),20.0,30.0)))");
EXPECT_TRUE(contains_gbdt(function.root(), 9));
EXPECT_TRUE(!contains_gbdt(function.root(), 10));
}
@@ -345,9 +345,9 @@ TEST("require that lazy parameters are not suggested for GBDT models") {
}
TEST("require that lazy parameters can be suggested for small GBDT models") {
- Function function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in 1),2.0,3.0),4.0))+"
- "if((d in 1),10.0,if((e<1),20.0,30.0))+"
- "if((d in 1),10.0,if((e<1),20.0,30.0))");
+ Function function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))+"
+ "if((d in [1]),10.0,if((e<1),20.0,30.0))+"
+ "if((d in [1]),10.0,if((e<1),20.0,30.0))");
EXPECT_TRUE(CompiledFunction::should_use_lazy_params(function));
}
diff --git a/eval/src/tests/eval/node_types/node_types_test.cpp b/eval/src/tests/eval/node_types/node_types_test.cpp
index e8009447793..0b6ea9e4d35 100644
--- a/eval/src/tests/eval/node_types/node_types_test.cpp
+++ b/eval/src/tests/eval/node_types/node_types_test.cpp
@@ -81,12 +81,6 @@ TEST("require that input parameters preserve their type") {
TEST_DO(verify("tensor(x{},y[10],z[])", "tensor(x{},y[10],z[])"));
}
-TEST("require that arrays are double (size) unless they contain an error") {
- TEST_DO(verify("[1,2,3]", "double"));
- TEST_DO(verify("[any,tensor,double]", "double"));
- TEST_DO(verify("[1,error,3]", "error"));
-}
-
TEST("require that if resolves to the appropriate type") {
TEST_DO(verify("if(error,1,2)", "error"));
TEST_DO(verify("if(1,error,2)", "error"));
@@ -108,17 +102,6 @@ TEST("require that if resolves to the appropriate type") {
TEST_DO(verify("if(double,any,double)", "any"));
}
-TEST("require that set membership resolves to double unless error") {
- TEST_DO(verify("1 in [1,2,3]", "double"));
- TEST_DO(verify("1 in [tensor,tensor,tensor]", "double"));
- TEST_DO(verify("1 in tensor", "double"));
- TEST_DO(verify("tensor in 1", "double"));
- TEST_DO(verify("tensor in [1,2,any]", "double"));
- TEST_DO(verify("any in [1,tensor,any]", "double"));
- TEST_DO(verify("error in [1,tensor,any]", "error"));
- TEST_DO(verify("any in [tensor,error,any]", "error"));
-}
-
TEST("require that reduce resolves correct type") {
TEST_DO(verify("reduce(error,sum)", "error"));
TEST_DO(verify("reduce(tensor,sum)", "double"));
@@ -244,6 +227,10 @@ TEST("require that map resolves correct type") {
TEST_DO(verify_op1("map(%s,f(x)(sin(x)))"));
}
+TEST("require that set membership resolves correct type") {
+ TEST_DO(verify_op1("%s in [1,2,3]"));
+}
+
TEST("require that join resolves correct type") {
TEST_DO(verify_op2("join(%s,%s,f(x,y)(x+y))"));
}
diff --git a/eval/src/vespa/eval/eval/basic_nodes.cpp b/eval/src/vespa/eval/eval/basic_nodes.cpp
index 50db7370a66..a96d634e07a 100644
--- a/eval/src/vespa/eval/eval/basic_nodes.cpp
+++ b/eval/src/vespa/eval/eval/basic_nodes.cpp
@@ -60,7 +60,7 @@ Node::traverse(NodeTraverser &traverser) const
void Number::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
void Symbol::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
void String::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
-void Array ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
+void In ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
void Neg ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
void Not ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
void If ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
@@ -124,7 +124,7 @@ If::If(Node_UP cond_in, Node_UP true_expr_in, Node_UP false_expr_in, double p_tr
if (less) {
_is_tree = (less->lhs().is_param() && less->rhs().is_const());
} else if (in) {
- _is_tree = (in->lhs().is_param() && in->rhs().is_const());
+ _is_tree = in->child().is_param();
}
}
}
diff --git a/eval/src/vespa/eval/eval/basic_nodes.h b/eval/src/vespa/eval/eval/basic_nodes.h
index 3f4c21be810..ebf65178b99 100644
--- a/eval/src/vespa/eval/eval/basic_nodes.h
+++ b/eval/src/vespa/eval/eval/basic_nodes.h
@@ -142,40 +142,39 @@ public:
void accept(NodeVisitor &visitor) const override;
};
-class Array : public Node {
+class In : public Node {
private:
- std::vector<Node_UP> _nodes;
- bool _is_const;
+ Node_UP _child;
+ std::vector<Node_UP> _entries;
public:
- Array() : _nodes(), _is_const(false) {}
- bool is_const() const override { return _is_const; }
- size_t size() const { return _nodes.size(); }
- const Node &get(size_t i) const { return *_nodes[i]; }
- size_t num_children() const override { return size(); }
- const Node &get_child(size_t idx) const override { return get(idx); }
- void detach_children(NodeHandler &handler) override {
- for (size_t i = 0; i < _nodes.size(); ++i) {
- handler.handle(std::move(_nodes[i]));
- }
- _nodes.clear();
+ In(Node_UP child) : _child(std::move(child)), _entries() {}
+ void add_entry(Node_UP entry) {
+ assert(entry->is_const());
+ _entries.push_back(std::move(entry));
}
- void add(Node_UP node) {
- if (_nodes.empty()) {
- _is_const = node->is_const();
- } else {
- _is_const = (_is_const && node->is_const());
- }
- _nodes.push_back(std::move(node));
+ size_t num_entries() const { return _entries.size(); }
+ const Node &get_entry(size_t idx) const { return *_entries[idx]; }
+ const Node &child() const { return *_child; }
+ size_t num_children() const override { return _child ? 1 : 0; }
+ const Node &get_child(size_t idx) const override {
+ (void) idx;
+ assert(idx == 0);
+ return child();
+ }
+ void detach_children(NodeHandler &handler) override {
+ handler.handle(std::move(_child));
}
vespalib::string dump(DumpContext &ctx) const override {
vespalib::string str;
- str += "[";
+ str += "(";
+ str += _child->dump(ctx);
+ str += " in [";
CommaTracker node_list;
- for (const auto &node: _nodes) {
+ for (const auto &node: _entries) {
node_list.maybe_comma(str);
str += node->dump(ctx);
}
- str += "]";
+ str += "])";
return str;
}
void accept(NodeVisitor &visitor) const override;
diff --git a/eval/src/vespa/eval/eval/function.cpp b/eval/src/vespa/eval/eval/function.cpp
index 0843baa700c..c4f91067260 100644
--- a/eval/src/vespa/eval/eval/function.cpp
+++ b/eval/src/vespa/eval/eval/function.cpp
@@ -281,12 +281,15 @@ public:
size_t operator_mark() const { return _operator_mark; }
void operator_mark(size_t mark) { _operator_mark = mark; }
- void push_operator(Operator_UP node) {
+ void apply_until(const nodes::Operator &op) {
while ((_operator_stack.size() > _operator_mark) &&
- (_operator_stack.back()->do_before(*node)))
+ (_operator_stack.back()->do_before(op)))
{
apply_operator();
}
+ }
+ void push_operator(Operator_UP node) {
+ apply_until(*node);
_operator_stack.push_back(std::move(node));
}
Operator_UP pop_operator() {
@@ -299,6 +302,7 @@ public:
//-----------------------------------------------------------------------------
+void parse_value(ParseContext &ctx);
void parse_expression(ParseContext &ctx);
int unhex(char c) {
@@ -642,8 +646,11 @@ void parse_symbol_or_call(ParseContext &ctx) {
}
}
-void parse_array(ParseContext &ctx) {
- std::unique_ptr<nodes::Array> array(new nodes::Array());
+void parse_in(ParseContext &ctx)
+{
+ ctx.apply_until(nodes::Less());
+ auto in = std::make_unique<nodes::In>(ctx.pop_expression());
+ ctx.skip_spaces();
ctx.eat('[');
ctx.skip_spaces();
size_t size = 0;
@@ -651,11 +658,19 @@ void parse_array(ParseContext &ctx) {
if (++size > 1) {
ctx.eat(',');
}
- parse_expression(ctx);
- array->add(ctx.pop_expression());
+ parse_value(ctx);
+ ctx.skip_spaces();
+ auto entry = ctx.pop_expression();
+ auto num = nodes::as<nodes::Number>(*entry);
+ auto str = nodes::as<nodes::String>(*entry);
+ if (num || str) {
+ in->add_entry(std::move(entry));
+ } else {
+ ctx.fail("invalid entry for 'in' operator");
+ }
}
ctx.eat(']');
- ctx.push_expression(std::move(array));
+ ctx.push_expression(std::move(in));
}
void parse_value(ParseContext &ctx) {
@@ -663,7 +678,13 @@ void parse_value(ParseContext &ctx) {
if (ctx.get() == '-') {
ctx.next();
parse_value(ctx);
- ctx.push_expression(Node_UP(new nodes::Neg(ctx.pop_expression())));
+ auto entry = ctx.pop_expression();
+ auto num = nodes::as<nodes::Number>(*entry);
+ if (num) {
+ ctx.push_expression(std::make_unique<nodes::Number>(-num->value()));
+ } else {
+ ctx.push_expression(std::make_unique<nodes::Neg>(std::move(entry)));
+ }
} else if (ctx.get() == '!') {
ctx.next();
parse_value(ctx);
@@ -672,8 +693,6 @@ void parse_value(ParseContext &ctx) {
ctx.next();
parse_expression(ctx);
ctx.eat(')');
- } else if (ctx.get() == '[') {
- parse_array(ctx);
} else if (ctx.get() == '"') {
parse_string(ctx);
} else if (isdigit(ctx.get())) {
@@ -683,7 +702,8 @@ void parse_value(ParseContext &ctx) {
}
}
-void parse_operator(ParseContext &ctx) {
+bool parse_operator(ParseContext &ctx) {
+ bool expect_value = true;
ctx.skip_spaces();
vespalib::string &str = ctx.peek(ctx.scratch(), nodes::OperatorRepo::instance().max_size());
Operator_UP op = nodes::OperatorRepo::instance().create(str);
@@ -691,24 +711,38 @@ void parse_operator(ParseContext &ctx) {
ctx.push_operator(std::move(op));
ctx.skip(str.size());
} else {
- ctx.fail(make_string("invalid operator: '%c'", ctx.get()));
+ vespalib::string ident = get_ident(ctx, true);
+ if (ident == "in") {
+ parse_in(ctx);
+ expect_value = false;
+ } else {
+ if (ident.empty()) {
+ ctx.fail(make_string("invalid operator: '%c'", ctx.get()));
+ } else {
+ ctx.fail(make_string("invalid operator: '%s'", ident.c_str()));
+ }
+ }
}
+ return expect_value;
}
void parse_expression(ParseContext &ctx) {
size_t old_mark = ctx.operator_mark();
ctx.operator_mark(ctx.num_operators());
+ bool expect_value = true;
for (;;) {
- parse_value(ctx);
+ if (expect_value) {
+ parse_value(ctx);
+ }
ctx.skip_spaces();
- if (ctx.eos() || ctx.get() == ')' || ctx.get() == ',' || ctx.get() == ']') {
+ if (ctx.eos() || ctx.get() == ')' || ctx.get() == ',') {
while (ctx.num_operators() > ctx.operator_mark()) {
ctx.apply_operator();
}
ctx.operator_mark(old_mark);
return;
}
- parse_operator(ctx);
+ expect_value = parse_operator(ctx);
}
}
diff --git a/eval/src/vespa/eval/eval/gbdt.cpp b/eval/src/vespa/eval/eval/gbdt.cpp
index cffd8bdb0c3..0edc42070ba 100644
--- a/eval/src/vespa/eval/eval/gbdt.cpp
+++ b/eval/src/vespa/eval/eval/gbdt.cpp
@@ -70,13 +70,11 @@ TreeStats::traverse(const nodes::Node &node, size_t depth, size_t &sum_path) {
++num_less_checks;
} else {
assert(in);
- auto symbol = nodes::as<nodes::Symbol>(in->lhs());
+ auto symbol = nodes::as<nodes::Symbol>(in->child());
assert(symbol);
num_params = std::max(num_params, size_t(symbol->id() + 1));
++num_in_checks;
- auto array = nodes::as<nodes::Array>(in->rhs());
- size_t array_size = (array) ? array->size() : 1;
- max_set_size = std::max(max_set_size, array_size);
+ max_set_size = std::max(max_set_size, in->num_entries());
}
return 1.0 + (p_true * true_path) + ((1.0 - p_true) * false_path);
} else {
diff --git a/eval/src/vespa/eval/eval/interpreted_function.cpp b/eval/src/vespa/eval/eval/interpreted_function.cpp
index caf71fcb68b..f99c4ace2dd 100644
--- a/eval/src/vespa/eval/eval/interpreted_function.cpp
+++ b/eval/src/vespa/eval/eval/interpreted_function.cpp
@@ -65,24 +65,6 @@ void op_skip_if_false(State &state, uint64_t param) {
//-----------------------------------------------------------------------------
-// compare lhs with a set member, short-circuit if found
-void op_check_member(State &state, uint64_t param) {
- if (state.peek(1).equal(state.peek(0))) {
- state.replace(2, state.stash.create<DoubleValue>(1.0));
- state.program_offset += param;
- } else {
- state.stack.pop_back();
- }
-}
-
-// set member not found, replace lhs with false
-void op_not_member(State &state, uint64_t) {
- state.stack.pop_back();
- state.stack.push_back(state.stash.create<DoubleValue>(0.0));
-}
-
-//-----------------------------------------------------------------------------
-
void op_double_map(State &state, uint64_t param) {
state.replace(1, state.stash.create<DoubleValue>(to_map_fun(param)(state.peek(0).as_double())));
}
@@ -252,8 +234,14 @@ struct ProgramBuilder : public NodeVisitor, public NodeTraverser {
void visit(const String &node) override {
make_const_op(node, stash.create<DoubleValue>(node.hash()));
}
- void visit(const Array &node) override {
- make_const_op(node, stash.create<DoubleValue>(node.size()));
+ void visit(const In &node) override {
+ auto my_in = std::make_unique<In>(std::make_unique<Symbol>(0));
+ for (size_t i = 0; i < node.num_entries(); ++i) {
+ my_in->add_entry(std::make_unique<Number>(node.get_entry(i).get_const_value()));
+ }
+ Function my_fun(std::move(my_in), {"x"});
+ const auto &token = stash.create<CompileCache::Token::UP>(CompileCache::compile(my_fun, PassParams::SEPARATE));
+ make_map_op(node, token.get()->get().get_function<1>());
}
void visit(const Neg &node) override {
make_map_op(node, operation::Neg::f);
@@ -367,26 +355,6 @@ struct ProgramBuilder : public NodeVisitor, public NodeTraverser {
void visit(const GreaterEqual &node) override {
make_join_op(node, operation::GreaterEqual::f);
}
- void visit(const In &node) override {
- std::vector<size_t> checks;
- node.lhs().traverse(*this);
- auto array = as<Array>(node.rhs());
- if (array) {
- for (size_t i = 0; i < array->size(); ++i) {
- array->get(i).traverse(*this);
- checks.push_back(program.size());
- program.emplace_back(op_check_member);
- }
- } else {
- node.rhs().traverse(*this);
- checks.push_back(program.size());
- program.emplace_back(op_check_member);
- }
- for (size_t i = 0; i < checks.size(); ++i) {
- program[checks[i]].update_param(program.size() - checks[i]);
- }
- program.emplace_back(op_not_member);
- }
void visit(const And &node) override {
make_join_op(node, operation::And::f);
}
@@ -472,7 +440,7 @@ struct ProgramBuilder : public NodeVisitor, public NodeTraverser {
//-------------------------------------------------------------------------
bool open(const Node &node) override {
- if (check_type<Array, If, In>(node)) {
+ if (check_type<If>(node)) {
node.accept(*this);
return false;
}
diff --git a/eval/src/vespa/eval/eval/key_gen.cpp b/eval/src/vespa/eval/eval/key_gen.cpp
index a745023fcf4..e0494e1fe11 100644
--- a/eval/src/vespa/eval/eval/key_gen.cpp
+++ b/eval/src/vespa/eval/eval/key_gen.cpp
@@ -25,7 +25,12 @@ struct KeyGen : public NodeVisitor, public NodeTraverser {
void visit(const Number &node) override { add_byte( 1); add_double(node.value()); }
void visit(const Symbol &node) override { add_byte( 2); add_int(node.id()); }
void visit(const String &node) override { add_byte( 3); add_hash(node.hash()); }
- void visit(const Array &node) override { add_byte( 4); add_size(node.size()); }
+ void visit(const In &node) override { add_byte( 4);
+ add_size(node.num_entries());
+ for (size_t i = 0; i < node.num_entries(); ++i) {
+ add_double(node.get_entry(i).get_const_value());
+ }
+ }
void visit(const Neg &) override { add_byte( 5); }
void visit(const Not &) override { add_byte( 6); }
void visit(const If &node) override { add_byte( 7); add_double(node.p_true()); }
@@ -49,7 +54,6 @@ struct KeyGen : public NodeVisitor, public NodeTraverser {
void visit(const LessEqual &) override { add_byte(30); }
void visit(const Greater &) override { add_byte(31); }
void visit(const GreaterEqual &) override { add_byte(32); }
- void visit(const In &) override { add_byte(33); }
void visit(const And &) override { add_byte(34); }
void visit(const Or &) override { add_byte(35); }
void visit(const Cos &) override { add_byte(36); }
diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
index 303833e8e6c..9355cf7a4e4 100644
--- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
+++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
@@ -56,9 +56,9 @@ namespace {
struct SetMemberHash : PluginState {
vespalib::hash_set<double> members;
- explicit SetMemberHash(const Array &array) : members(array.size() * 3) {
- for (size_t i = 0; i < array.size(); ++i) {
- members.insert(array.get(i).get_const_value());
+ explicit SetMemberHash(const In &in) : members(in.num_entries() * 3) {
+ for (size_t i = 0; i < in.num_entries(); ++i) {
+ members.insert(in.get_entry(i).get_const_value());
}
}
static bool check_membership(const PluginState *state, double value) {
@@ -252,7 +252,7 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser {
inside_forest = true;
forest_end = &node;
}
- if (check_type<Array, If, In>(node)) {
+ if (check_type<If>(node)) {
node.accept(*this);
return false;
}
@@ -355,9 +355,27 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser {
void visit(const String &item) override {
push_double(item.hash());
}
- void visit(const Array &item) override {
- // NB: visit not open
- push_double(item.size());
+ void visit(const In &item) override {
+ llvm::Value *lhs = pop_double();
+ if (item.num_entries() > 8) {
+ // build call to hash lookup
+ plugin_state.emplace_back(new SetMemberHash(item));
+ void *call_ptr = (void *) SetMemberHash::check_membership;
+ PluginState *state = plugin_state.back().get();
+ llvm::PointerType *funptr_t = make_check_membership_funptr_t();
+ llvm::Value *call_fun = builder.CreateIntToPtr(builder.getInt64((uint64_t)call_ptr), funptr_t, "inject_call_addr");
+ llvm::Value *ctx = builder.CreateIntToPtr(builder.getInt64((uint64_t)state), builder.getVoidTy()->getPointerTo(), "inject_ctx");
+ push(builder.CreateCall(call_fun, {ctx, lhs}, "call_check_membership"));
+ } else {
+ // build explicit code to check all set members
+ llvm::Value *found = builder.getFalse();
+ for (size_t i = 0; i < item.num_entries(); ++i) {
+ llvm::Value *elem = llvm::ConstantFP::get(builder.getDoubleTy(), item.get_entry(i).get_const_value());
+ llvm::Value *elem_eq = builder.CreateFCmpOEQ(lhs, elem, "elem_eq");
+ found = builder.CreateOr(found, elem_eq, "found");
+ }
+ push(found);
+ }
}
void visit(const Neg &) override {
llvm::Value *child = pop_double();
@@ -480,38 +498,6 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser {
llvm::Value *a = pop_double();
push(builder.CreateFCmpOGE(a, b, "cmp_ge_res"));
}
- void visit(const In &item) override {
- // NB: visit not open
- item.lhs().traverse(*this); // NB: recursion
- llvm::Value *lhs = pop_double();
- auto array = as<Array>(item.rhs());
- if (array) {
- if (array->is_const() && array->size() > 8) {
- // build call to hash lookup
- plugin_state.emplace_back(new SetMemberHash(*array));
- void *call_ptr = (void *) SetMemberHash::check_membership;
- PluginState *state = plugin_state.back().get();
- llvm::PointerType *funptr_t = make_check_membership_funptr_t();
- llvm::Value *call_fun = builder.CreateIntToPtr(builder.getInt64((uint64_t)call_ptr), funptr_t, "inject_call_addr");
- llvm::Value *ctx = builder.CreateIntToPtr(builder.getInt64((uint64_t)state), builder.getVoidTy()->getPointerTo(), "inject_ctx");
- push(builder.CreateCall(call_fun, {ctx, lhs}, "call_check_membership"));
- } else {
- // build explicit code to check all set members
- llvm::Value *found = builder.getFalse();
- for (size_t i = 0; i < array->size(); ++i) {
- array->get(i).traverse(*this); // NB: recursion
- llvm::Value *elem = pop_double();
- llvm::Value *elem_eq = builder.CreateFCmpOEQ(lhs, elem, "elem_eq");
- found = builder.CreateOr(found, elem_eq, "found");
- }
- push(found);
- }
- } else {
- item.rhs().traverse(*this); // NB: recursion
- llvm::Value *rhs = pop_double();
- push(builder.CreateFCmpOEQ(lhs, rhs, "rhs_eq"));
- }
- }
void visit(const And &) override {
llvm::Value *b = pop_bool();
llvm::Value *a = pop_bool();
diff --git a/eval/src/vespa/eval/eval/node_types.cpp b/eval/src/vespa/eval/eval/node_types.cpp
index d995c5281c4..f86c3e1a84a 100644
--- a/eval/src/vespa/eval/eval/node_types.cpp
+++ b/eval/src/vespa/eval/eval/node_types.cpp
@@ -90,9 +90,7 @@ struct TypeResolver : public NodeVisitor, public NodeTraverser {
void visit(const String &node) override {
bind_type(ValueType::double_type(), node);
}
- void visit(const Array &node) override {
- bind_type(ValueType::double_type(), node);
- }
+ void visit(const In &node) override { resolve_op1(node); }
void visit(const Neg &node) override { resolve_op1(node); }
void visit(const Not &node) override { resolve_op1(node); }
void visit(const If &node) override {
@@ -139,9 +137,6 @@ struct TypeResolver : public NodeVisitor, public NodeTraverser {
void visit(const LessEqual &node) override { resolve_op2(node); }
void visit(const Greater &node) override { resolve_op2(node); }
void visit(const GreaterEqual &node) override { resolve_op2(node); }
- void visit(const In &node) override {
- bind_type(ValueType::double_type(), node);
- }
void visit(const And &node) override { resolve_op2(node); }
void visit(const Or &node) override { resolve_op2(node); }
void visit(const Cos &node) override { resolve_op1(node); }
diff --git a/eval/src/vespa/eval/eval/node_visitor.h b/eval/src/vespa/eval/eval/node_visitor.h
index 91d65ceb7ec..c5a6fd51373 100644
--- a/eval/src/vespa/eval/eval/node_visitor.h
+++ b/eval/src/vespa/eval/eval/node_visitor.h
@@ -22,7 +22,7 @@ struct NodeVisitor {
virtual void visit(const nodes::Number &) = 0;
virtual void visit(const nodes::Symbol &) = 0;
virtual void visit(const nodes::String &) = 0;
- virtual void visit(const nodes::Array &) = 0;
+ virtual void visit(const nodes::In &) = 0;
virtual void visit(const nodes::Neg &) = 0;
virtual void visit(const nodes::Not &) = 0;
virtual void visit(const nodes::If &) = 0;
@@ -50,7 +50,6 @@ struct NodeVisitor {
virtual void visit(const nodes::LessEqual &) = 0;
virtual void visit(const nodes::Greater &) = 0;
virtual void visit(const nodes::GreaterEqual &) = 0;
- virtual void visit(const nodes::In &) = 0;
virtual void visit(const nodes::And &) = 0;
virtual void visit(const nodes::Or &) = 0;
@@ -92,7 +91,7 @@ struct EmptyNodeVisitor : NodeVisitor {
void visit(const nodes::Number &) override {}
void visit(const nodes::Symbol &) override {}
void visit(const nodes::String &) override {}
- void visit(const nodes::Array &) override {}
+ void visit(const nodes::In &) override {}
void visit(const nodes::Neg &) override {}
void visit(const nodes::Not &) override {}
void visit(const nodes::If &) override {}
@@ -116,7 +115,6 @@ struct EmptyNodeVisitor : NodeVisitor {
void visit(const nodes::LessEqual &) override {}
void visit(const nodes::Greater &) override {}
void visit(const nodes::GreaterEqual &) override {}
- void visit(const nodes::In &) override {}
void visit(const nodes::And &) override {}
void visit(const nodes::Or &) override {}
void visit(const nodes::Cos &) override {}
diff --git a/eval/src/vespa/eval/eval/operator_nodes.cpp b/eval/src/vespa/eval/eval/operator_nodes.cpp
index 11817630da4..4c66268dfa2 100644
--- a/eval/src/vespa/eval/eval/operator_nodes.cpp
+++ b/eval/src/vespa/eval/eval/operator_nodes.cpp
@@ -37,23 +37,10 @@ OperatorRepo::OperatorRepo() : _map(), _max_size(0) {
add(nodes::LessEqual());
add(nodes::Greater());
add(nodes::GreaterEqual());
- add(nodes::In());
add(nodes::And());
add(nodes::Or());
}
-vespalib::string
-In::dump(DumpContext &ctx) const
-{
- vespalib::string str;
- str += "(";
- str += lhs().dump(ctx);
- str += " in ";
- str += rhs().dump(ctx);
- str += ")";
- return str;
-}
-
} // namespace vespalib::eval::nodes
} // namespace vespalib::eval
} // namespace vespalib
diff --git a/eval/src/vespa/eval/eval/operator_nodes.h b/eval/src/vespa/eval/eval/operator_nodes.h
index e4dda484d68..eafd817d42c 100644
--- a/eval/src/vespa/eval/eval/operator_nodes.h
+++ b/eval/src/vespa/eval/eval/operator_nodes.h
@@ -166,9 +166,6 @@ struct Less : OperatorHelper<Less> { Less() : Helper("<"
struct LessEqual : OperatorHelper<LessEqual> { LessEqual() : Helper("<=", 10, LEFT) {}};
struct Greater : OperatorHelper<Greater> { Greater() : Helper(">", 10, LEFT) {}};
struct GreaterEqual : OperatorHelper<GreaterEqual> { GreaterEqual() : Helper(">=", 10, LEFT) {}};
-struct In : OperatorHelper<In> { In() : Helper("in", 10, LEFT) {}
- virtual vespalib::string dump(DumpContext &ctx) const override;
-};
struct And : OperatorHelper<And> { And() : Helper("&&", 2, LEFT) {}};
struct Or : OperatorHelper<Or> { Or() : Helper("||", 1, LEFT) {}};
diff --git a/eval/src/vespa/eval/eval/test/eval_spec.cpp b/eval/src/vespa/eval/eval/test/eval_spec.cpp
index c7c0d754976..d214486cf21 100644
--- a/eval/src/vespa/eval/eval/test/eval_spec.cpp
+++ b/eval/src/vespa/eval/eval/test/eval_spec.cpp
@@ -104,12 +104,6 @@ EvalSpec::add_terminal_cases() {
add_expression({}, "10").add_case({}, 10.0);
add_expression({}, "100").add_case({}, 100.0);
add_rule({"a", -5.0, 5.0}, "a", [](double a){ return a; });
- add_expression({}, "[]").add_case({}, 0.0);
- add_expression({}, "[1]").add_case({}, 1.0);
- add_expression({}, "[1,2]").add_case({}, 2.0);
- add_expression({}, "[1,2,3]").add_case({}, 3.0);
- add_expression({}, "[3,2,1]").add_case({}, 3.0);
- add_expression({}, "[1,1,1,1,1]").add_case({}, 5.0);
add_expression({}, "\"\"").add_case({}, vespalib::hash_code(""));
add_expression({}, "\"foo\"").add_case({}, vespalib::hash_code("foo"));
add_expression({}, "\"foo bar baz\"").add_case({}, vespalib::hash_code("foo bar baz"));
@@ -277,52 +271,27 @@ EvalSpec::add_set_membership_cases()
{
add_expression({"a"}, "(a in [])")
.add_case({0.0}, 0.0)
- .add_case({1.0}, 0.0)
- .add_case({2.0}, 0.0);
-
- add_expression({"a"}, "(a in [[]])")
- .add_case({0.0}, 1.0)
- .add_case({1.0}, 0.0)
- .add_case({2.0}, 0.0);
-
- add_expression({"a"}, "(a in [[[]]])")
- .add_case({0.0}, 0.0)
- .add_case({1.0}, 1.0)
- .add_case({2.0}, 0.0);
-
- add_expression({"a", "b"}, "(a in b)")
- .add_case({my_nan, 2.0}, 0.0)
- .add_case({2.0, my_nan}, 0.0)
- .add_case({my_nan, my_nan}, 0.0)
- .add_case({1.0, 2.0}, 0.0)
- .add_case({2.0 - 1e-10, 2.0}, 0.0)
- .add_case({2.0, 2.0}, 1.0)
- .add_case({2.0 + 1e-10, 2.0}, 0.0)
- .add_case({3.0, 2.0}, 0.0);
-
- add_expression({"a", "b"}, "(a in [b])")
- .add_case({my_nan, 2.0}, 0.0)
- .add_case({2.0, my_nan}, 0.0)
- .add_case({my_nan, my_nan}, 0.0)
- .add_case({1.0, 2.0}, 0.0)
- .add_case({2.0 - 1e-10, 2.0}, 0.0)
- .add_case({2.0, 2.0}, 1.0)
- .add_case({2.0 + 1e-10, 2.0}, 0.0)
- .add_case({3.0, 2.0}, 0.0);
-
- add_expression({"a", "b"}, "(a in [[b]])")
- .add_case({1.0, 2.0}, 1.0)
- .add_case({2.0, 2.0}, 0.0);
-
- add_expression({"a", "b", "c", "d"}, "(a in [b,c,d])")
- .add_case({0.0, 10.0, 20.0, 30.0}, 0.0)
- .add_case({3.0, 10.0, 20.0, 30.0}, 0.0)
- .add_case({10.0, 10.0, 20.0, 30.0}, 1.0)
- .add_case({20.0, 10.0, 20.0, 30.0}, 1.0)
- .add_case({30.0, 10.0, 20.0, 30.0}, 1.0)
- .add_case({10.0, 30.0, 20.0, 10.0}, 1.0)
- .add_case({20.0, 30.0, 20.0, 10.0}, 1.0)
- .add_case({30.0, 30.0, 20.0, 10.0}, 1.0);
+ .add_case({1.0}, 0.0);
+
+ add_expression({"a"}, "(a in [2.0])")
+ .add_case({my_nan}, 0.0)
+ .add_case({1.0}, 0.0)
+ .add_case({2.0 - 1e-10}, 0.0)
+ .add_case({2.0}, 1.0)
+ .add_case({2.0 + 1e-10}, 0.0)
+ .add_case({3.0}, 0.0);
+
+ add_expression({"a"}, "(a in [10,20,30])")
+ .add_case({0.0}, 0.0)
+ .add_case({3.0}, 0.0)
+ .add_case({10.0}, 1.0)
+ .add_case({20.0}, 1.0)
+ .add_case({30.0}, 1.0);
+
+ add_expression({"a"}, "(a in [30,20,10])")
+ .add_case({10.0}, 1.0)
+ .add_case({20.0}, 1.0)
+ .add_case({30.0}, 1.0);
}
void
diff --git a/eval/src/vespa/eval/eval/vm_forest.cpp b/eval/src/vespa/eval/eval/vm_forest.cpp
index 4a73394e354..c456c660af7 100644
--- a/eval/src/vespa/eval/eval/vm_forest.cpp
+++ b/eval/src/vespa/eval/eval/vm_forest.cpp
@@ -128,20 +128,13 @@ void encode_in(const nodes::In &in,
std::vector<uint32_t> &model_out)
{
size_t meta_idx = model_out.size();
- auto symbol = nodes::as<nodes::Symbol>(in.lhs());
+ auto symbol = nodes::as<nodes::Symbol>(in.child());
assert(symbol);
model_out.push_back(uint32_t(symbol->id()) << 12);
- assert(in.rhs().is_const());
- auto array = nodes::as<nodes::Array>(in.rhs());
size_t set_size_idx = model_out.size();
- if (array) {
- model_out.push_back(array->size());
- for (size_t i = 0; i < array->size(); ++i) {
- encode_const(array->get(i).get_const_value(), model_out);
- }
- } else {
- model_out.push_back(1);
- encode_const(in.rhs().get_const_value(), model_out);
+ model_out.push_back(in.num_entries());
+ for (size_t i = 0; i < in.num_entries(); ++i) {
+ encode_const(in.get_entry(i).get_const_value(), model_out);
}
size_t left_idx = model_out.size();
uint32_t left_type = encode_node(left_child, model_out);