diff options
31 files changed, 830 insertions, 673 deletions
diff --git a/eval/src/apps/eval_expr/eval_expr.cpp b/eval/src/apps/eval_expr/eval_expr.cpp index afddec40e48..4e6fec926e7 100644 --- a/eval/src/apps/eval_expr/eval_expr.cpp +++ b/eval/src/apps/eval_expr/eval_expr.cpp @@ -14,12 +14,12 @@ int main(int argc, char **argv) { fprintf(stderr, " quote the expression to make it a single parameter\n"); return 1; } - Function function = Function::parse({}, argv[1]); - if (function.has_error()) { - fprintf(stderr, "expression error: %s\n", function.get_error().c_str()); + auto function = Function::parse({}, argv[1]); + if (function->has_error()) { + fprintf(stderr, "expression error: %s\n", function->get_error().c_str()); return 1; } - InterpretedFunction interpreted(SimpleTensorEngine::ref(), function, NodeTypes()); + InterpretedFunction interpreted(SimpleTensorEngine::ref(), *function, NodeTypes()); InterpretedFunction::Context ctx(interpreted); SimpleParams params({}); const Value &result = interpreted.eval(ctx, params); diff --git a/eval/src/apps/tensor_conformance/tensor_conformance.cpp b/eval/src/apps/tensor_conformance/tensor_conformance.cpp index 1a760a436b6..59fe2960dbb 100644 --- a/eval/src/apps/tensor_conformance/tensor_conformance.cpp +++ b/eval/src/apps/tensor_conformance/tensor_conformance.cpp @@ -89,20 +89,20 @@ std::vector<ValueType> get_types(const std::vector<Value::UP> ¶m_values) { } TensorSpec eval_expr(const Inspector &test, const TensorEngine &engine, bool typed) { - Function fun = Function::parse(test["expression"].asString().make_string()); + auto fun = Function::parse(test["expression"].asString().make_string()); std::vector<Value::UP> param_values; std::vector<Value::CREF> param_refs; - for (size_t i = 0; i < fun.num_params(); ++i) { - param_values.emplace_back(engine.from_spec(extract_value(test["inputs"][fun.param_name(i)]))); + for (size_t i = 0; i < fun->num_params(); ++i) { + param_values.emplace_back(engine.from_spec(extract_value(test["inputs"][fun->param_name(i)]))); param_refs.emplace_back(*param_values.back()); } - NodeTypes types = typed ? NodeTypes(fun, get_types(param_values)) : NodeTypes(); - InterpretedFunction ifun(engine, fun, types); + NodeTypes types = typed ? NodeTypes(*fun, get_types(param_values)) : NodeTypes(); + InterpretedFunction ifun(engine, *fun, types); InterpretedFunction::Context ctx(ifun); SimpleObjectParams params(param_refs); const Value &result = ifun.eval(ctx, params); if (typed) { - ASSERT_EQUAL(result.type(), types.get_type(fun.root())); + ASSERT_EQUAL(result.type(), types.get_type(fun->root())); } return engine.to_spec(result); } @@ -236,12 +236,12 @@ struct TestSpec { const auto &my_expression = test["expression"]; ASSERT_TRUE(my_expression.valid()); expression = my_expression.asString().make_string(); - Function fun = Function::parse(expression); - ASSERT_TRUE(!fun.has_error()); - ASSERT_EQUAL(fun.num_params(), test["inputs"].fields()); - for (size_t i = 0; i < fun.num_params(); ++i) { + auto fun = Function::parse(expression); + ASSERT_TRUE(!fun->has_error()); + ASSERT_EQUAL(fun->num_params(), test["inputs"].fields()); + for (size_t i = 0; i < fun->num_params(); ++i) { TEST_STATE(make_string("input #%zu", i).c_str()); - const auto &my_input = test["inputs"][fun.param_name(i)]; + const auto &my_input = test["inputs"][fun->param_name(i)]; ASSERT_TRUE(my_input.valid()); inputs.push_back(extract_value(my_input)); } diff --git a/eval/src/tests/eval/compile_cache/compile_cache_test.cpp b/eval/src/tests/eval/compile_cache/compile_cache_test.cpp index 6796d498719..01a29b9b9e8 100644 --- a/eval/src/tests/eval/compile_cache/compile_cache_test.cpp +++ b/eval/src/tests/eval/compile_cache/compile_cache_test.cpp @@ -3,57 +3,54 @@ #include <vespa/eval/eval/llvm/compile_cache.h> #include <vespa/eval/eval/key_gen.h> #include <vespa/eval/eval/test/eval_spec.h> +#include <vespa/vespalib/util/time.h> +#include <vespa/vespalib/util/threadstackexecutor.h> +#include <thread> #include <set> +using namespace vespalib; using namespace vespalib::eval; //----------------------------------------------------------------------------- TEST("require that parameter passing selection affects function key") { - EXPECT_NOT_EQUAL(gen_key(Function::parse("a+b"), PassParams::SEPARATE), - gen_key(Function::parse("a+b"), PassParams::ARRAY)); + EXPECT_NOT_EQUAL(gen_key(*Function::parse("a+b"), PassParams::SEPARATE), + gen_key(*Function::parse("a+b"), PassParams::ARRAY)); } TEST("require that the number of parameters affects function key") { - EXPECT_NOT_EQUAL(gen_key(Function::parse({"a", "b"}, "a+b"), PassParams::SEPARATE), - gen_key(Function::parse({"a", "b", "c"}, "a+b"), PassParams::SEPARATE)); - EXPECT_NOT_EQUAL(gen_key(Function::parse({"a", "b"}, "a+b"), PassParams::ARRAY), - gen_key(Function::parse({"a", "b", "c"}, "a+b"), PassParams::ARRAY)); + EXPECT_NOT_EQUAL(gen_key(*Function::parse({"a", "b"}, "a+b"), PassParams::SEPARATE), + gen_key(*Function::parse({"a", "b", "c"}, "a+b"), PassParams::SEPARATE)); + EXPECT_NOT_EQUAL(gen_key(*Function::parse({"a", "b"}, "a+b"), PassParams::ARRAY), + gen_key(*Function::parse({"a", "b", "c"}, "a+b"), PassParams::ARRAY)); } TEST("require that implicit and explicit parameters give the same function key") { - EXPECT_EQUAL(gen_key(Function::parse({"a", "b"}, "a+b"), PassParams::SEPARATE), - gen_key(Function::parse("a+b"), PassParams::SEPARATE)); - EXPECT_EQUAL(gen_key(Function::parse({"a", "b"}, "a+b"), PassParams::ARRAY), - gen_key(Function::parse("a+b"), PassParams::ARRAY)); + EXPECT_EQUAL(gen_key(*Function::parse({"a", "b"}, "a+b"), PassParams::SEPARATE), + gen_key(*Function::parse("a+b"), PassParams::SEPARATE)); + EXPECT_EQUAL(gen_key(*Function::parse({"a", "b"}, "a+b"), PassParams::ARRAY), + gen_key(*Function::parse("a+b"), PassParams::ARRAY)); } TEST("require that symbol names does not affect function key") { - EXPECT_EQUAL(gen_key(Function::parse("a+b"), PassParams::SEPARATE), - gen_key(Function::parse("x+y"), PassParams::SEPARATE)); - EXPECT_EQUAL(gen_key(Function::parse("a+b"), PassParams::ARRAY), - gen_key(Function::parse("x+y"), PassParams::ARRAY)); -} - -TEST("require that let bind names does not affect function key") { - EXPECT_EQUAL(gen_key(Function::parse("let(a,1,a+a)"), PassParams::SEPARATE), - gen_key(Function::parse("let(b,1,b+b)"), PassParams::SEPARATE)); - EXPECT_EQUAL(gen_key(Function::parse("let(a,1,a+a)"), PassParams::ARRAY), - gen_key(Function::parse("let(b,1,b+b)"), PassParams::ARRAY)); + EXPECT_EQUAL(gen_key(*Function::parse("a+b"), PassParams::SEPARATE), + gen_key(*Function::parse("x+y"), PassParams::SEPARATE)); + EXPECT_EQUAL(gen_key(*Function::parse("a+b"), PassParams::ARRAY), + gen_key(*Function::parse("x+y"), PassParams::ARRAY)); } TEST("require that different values give different function keys") { - EXPECT_NOT_EQUAL(gen_key(Function::parse("1"), PassParams::SEPARATE), - gen_key(Function::parse("2"), PassParams::SEPARATE)); - EXPECT_NOT_EQUAL(gen_key(Function::parse("1"), PassParams::ARRAY), - gen_key(Function::parse("2"), PassParams::ARRAY)); + EXPECT_NOT_EQUAL(gen_key(*Function::parse("1"), PassParams::SEPARATE), + gen_key(*Function::parse("2"), PassParams::SEPARATE)); + EXPECT_NOT_EQUAL(gen_key(*Function::parse("1"), PassParams::ARRAY), + gen_key(*Function::parse("2"), PassParams::ARRAY)); } TEST("require that different strings give different function keys") { - EXPECT_NOT_EQUAL(gen_key(Function::parse("\"a\""), PassParams::SEPARATE), - gen_key(Function::parse("\"b\""), PassParams::SEPARATE)); - EXPECT_NOT_EQUAL(gen_key(Function::parse("\"a\""), PassParams::ARRAY), - gen_key(Function::parse("\"b\""), PassParams::ARRAY)); + EXPECT_NOT_EQUAL(gen_key(*Function::parse("\"a\""), PassParams::SEPARATE), + gen_key(*Function::parse("\"b\""), PassParams::SEPARATE)); + EXPECT_NOT_EQUAL(gen_key(*Function::parse("\"a\""), PassParams::ARRAY), + gen_key(*Function::parse("\"b\""), PassParams::ARRAY)); } //----------------------------------------------------------------------------- @@ -69,11 +66,11 @@ struct CheckKeys : test::EvalSpec::EvalTest { virtual void next_expression(const std::vector<vespalib::string> ¶m_names, const vespalib::string &expression) override { - Function function = Function::parse(param_names, expression); - if (!CompiledFunction::detect_issues(function)) { - if (check_key(gen_key(function, PassParams::ARRAY)) || - check_key(gen_key(function, PassParams::SEPARATE)) || - check_key(gen_key(function, PassParams::LAZY))) + auto function = Function::parse(param_names, expression); + if (!CompiledFunction::detect_issues(*function)) { + if (check_key(gen_key(*function, PassParams::ARRAY)) || + check_key(gen_key(*function, PassParams::SEPARATE)) || + check_key(gen_key(*function, PassParams::LAZY))) { failed = true; fprintf(stderr, "key collision for: %s\n", expression.c_str()); @@ -107,33 +104,33 @@ TEST("require that cache is initially empty") { } TEST("require that unused functions are evicted from the cache") { - CompileCache::Token::UP token_a = CompileCache::compile(Function::parse("x+y"), PassParams::ARRAY); + CompileCache::Token::UP token_a = CompileCache::compile(*Function::parse("x+y"), PassParams::ARRAY); TEST_DO(verify_cache(1, 1)); token_a.reset(); TEST_DO(verify_cache(0, 0)); } TEST("require that agents can have separate functions in the cache") { - CompileCache::Token::UP token_a = CompileCache::compile(Function::parse("x+y"), PassParams::ARRAY); - CompileCache::Token::UP token_b = CompileCache::compile(Function::parse("x*y"), PassParams::ARRAY); + CompileCache::Token::UP token_a = CompileCache::compile(*Function::parse("x+y"), PassParams::ARRAY); + CompileCache::Token::UP token_b = CompileCache::compile(*Function::parse("x*y"), PassParams::ARRAY); TEST_DO(verify_cache(2, 2)); } TEST("require that agents can share functions in the cache") { - CompileCache::Token::UP token_a = CompileCache::compile(Function::parse("x+y"), PassParams::ARRAY); - CompileCache::Token::UP token_b = CompileCache::compile(Function::parse("x+y"), PassParams::ARRAY); + CompileCache::Token::UP token_a = CompileCache::compile(*Function::parse("x+y"), PassParams::ARRAY); + CompileCache::Token::UP token_b = CompileCache::compile(*Function::parse("x+y"), PassParams::ARRAY); TEST_DO(verify_cache(1, 2)); } TEST("require that cache usage works") { TEST_DO(verify_cache(0, 0)); - CompileCache::Token::UP token_a = CompileCache::compile(Function::parse("x+y"), PassParams::SEPARATE); + CompileCache::Token::UP token_a = CompileCache::compile(*Function::parse("x+y"), PassParams::SEPARATE); EXPECT_EQUAL(5.0, token_a->get().get_function<2>()(2.0, 3.0)); TEST_DO(verify_cache(1, 1)); - CompileCache::Token::UP token_b = CompileCache::compile(Function::parse("x*y"), PassParams::SEPARATE); + CompileCache::Token::UP token_b = CompileCache::compile(*Function::parse("x*y"), PassParams::SEPARATE); EXPECT_EQUAL(6.0, token_b->get().get_function<2>()(2.0, 3.0)); TEST_DO(verify_cache(2, 2)); - CompileCache::Token::UP token_c = CompileCache::compile(Function::parse("x+y"), PassParams::SEPARATE); + CompileCache::Token::UP token_c = CompileCache::compile(*Function::parse("x+y"), PassParams::SEPARATE); EXPECT_EQUAL(5.0, token_c->get().get_function<2>()(2.0, 3.0)); TEST_DO(verify_cache(2, 3)); token_a.reset(); @@ -144,6 +141,78 @@ TEST("require that cache usage works") { TEST_DO(verify_cache(0, 0)); } +struct CompileCheck : test::EvalSpec::EvalTest { + struct Entry { + CompileCache::Token::UP fun; + std::vector<double> params; + double expect; + Entry(CompileCache::Token::UP fun_in, const std::vector<double> ¶ms_in, double expect_in) + : fun(std::move(fun_in)), params(params_in), expect(expect_in) {} + }; + std::vector<Entry> list; + void next_expression(const std::vector<vespalib::string> &, + const vespalib::string &) override {} + void handle_case(const std::vector<vespalib::string> ¶m_names, + const std::vector<double> ¶m_values, + const vespalib::string &expression, + double expected_result) override + { + auto function = Function::parse(param_names, expression); + ASSERT_TRUE(!function->has_error()); + bool has_issues = CompiledFunction::detect_issues(*function); + if (!has_issues) { + list.emplace_back(CompileCache::compile(*function, PassParams::ARRAY), param_values, expected_result); + } + } + void verify() { + for (const Entry &entry: list) { + auto fun = entry.fun->get().get_function(); + if (std::isnan(entry.expect)) { + EXPECT_TRUE(std::isnan(fun(&entry.params[0]))); + } else { + EXPECT_EQUAL(fun(&entry.params[0]), entry.expect); + } + } + } +}; + +TEST_F("compile sequentially, then run all conformance tests", test::EvalSpec()) { + f1.add_all_cases(); + for (size_t i = 0; i < 4; ++i) { + CompileCheck test; + auto t0 = steady_clock::now(); + f1.each_case(test); + auto t1 = steady_clock::now(); + auto t2 = steady_clock::now(); + test.verify(); + auto t3 = steady_clock::now(); + fprintf(stderr, "sequential (run %zu): setup: %zu ms, wait: %zu ms, verify: %zu us, total: %zu ms\n", + i, count_ms(t1 - t0), count_ms(t2 - t1), count_us(t3 - t2), count_ms(t3 - t0)); + } +} + +TEST_FF("compile concurrently (8 threads), then run all conformance tests", test::EvalSpec(), TimeBomb(60)) { + f1.add_all_cases(); + ThreadStackExecutor executor(8, 256*1024); + CompileCache::attach_executor(executor); + while (executor.num_idle_workers() < 8) { + std::this_thread::sleep_for(1ms); + } + for (size_t i = 0; i < 4; ++i) { + CompileCheck test; + auto t0 = steady_clock::now(); + f1.each_case(test); + auto t1 = steady_clock::now(); + executor.sync(); + auto t2 = steady_clock::now(); + test.verify(); + auto t3 = steady_clock::now(); + fprintf(stderr, "concurrent (run %zu): setup: %zu ms, wait: %zu ms, verify: %zu us, total: %zu ms\n", + i, count_ms(t1 - t0), count_ms(t2 - t1), count_us(t3 - t2), count_ms(t3 - t0)); + } + CompileCache::detach_executor(); +} + //----------------------------------------------------------------------------- TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/eval/src/tests/eval/compiled_function/compiled_function_test.cpp b/eval/src/tests/eval/compiled_function/compiled_function_test.cpp index 1d5a6929083..b01c849da1e 100644 --- a/eval/src/tests/eval/compiled_function/compiled_function_test.cpp +++ b/eval/src/tests/eval/compiled_function/compiled_function_test.cpp @@ -18,7 +18,7 @@ std::vector<vespalib::string> params_10({"p1", "p2", "p3", "p4", "p5", "p6", "p7 const char *expr_10 = "p1 + p2 + p3 + p4 + p5 + p6 + p7 + p8 + p9 + p10"; TEST("require that separate parameter passing works") { - CompiledFunction cf_10(Function::parse(params_10, expr_10), PassParams::SEPARATE); + CompiledFunction cf_10(*Function::parse(params_10, expr_10), PassParams::SEPARATE); auto fun_10 = cf_10.get_function<10>(); EXPECT_EQUAL(10.0, fun_10(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0)); EXPECT_EQUAL(50.0, fun_10(5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0)); @@ -27,7 +27,7 @@ TEST("require that separate parameter passing works") { } TEST("require that array parameter passing works") { - CompiledFunction arr_cf(Function::parse(params_10, expr_10), PassParams::ARRAY); + CompiledFunction arr_cf(*Function::parse(params_10, expr_10), PassParams::ARRAY); auto arr_fun = arr_cf.get_function(); EXPECT_EQUAL(10.0, arr_fun(&std::vector<double>({1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0})[0])); EXPECT_EQUAL(50.0, arr_fun(&std::vector<double>({5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0})[0])); @@ -38,7 +38,7 @@ TEST("require that array parameter passing works") { double my_resolve(void *ctx, size_t idx) { return ((double *)ctx)[idx]; } TEST("require that lazy parameter passing works") { - CompiledFunction lazy_cf(Function::parse(params_10, expr_10), PassParams::LAZY); + CompiledFunction lazy_cf(*Function::parse(params_10, expr_10), PassParams::LAZY); auto lazy_fun = lazy_cf.get_lazy_function(); EXPECT_EQUAL(10.0, lazy_fun(my_resolve, &std::vector<double>({1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0})[0])); EXPECT_EQUAL(50.0, lazy_fun(my_resolve, &std::vector<double>({5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0})[0])); @@ -79,10 +79,10 @@ struct MyEvalTest : test::EvalSpec::EvalTest { virtual void next_expression(const std::vector<vespalib::string> ¶m_names, const vespalib::string &expression) override { - Function function = Function::parse(param_names, expression); - ASSERT_TRUE(!function.has_error()); + auto function = Function::parse(param_names, expression); + ASSERT_TRUE(!function->has_error()); bool is_supported = !is_unsupported(expression); - bool has_issues = CompiledFunction::detect_issues(function); + bool has_issues = CompiledFunction::detect_issues(*function); if (is_supported == has_issues) { const char *supported_str = is_supported ? "supported" : "not supported"; const char *issues_str = has_issues ? "has issues" : "does not have issues"; @@ -96,12 +96,12 @@ struct MyEvalTest : test::EvalSpec::EvalTest { const vespalib::string &expression, double expected_result) override { - Function function = Function::parse(param_names, expression); - ASSERT_TRUE(!function.has_error()); + auto function = Function::parse(param_names, expression); + ASSERT_TRUE(!function->has_error()); bool is_supported = !is_unsupported(expression); - bool has_issues = CompiledFunction::detect_issues(function); + bool has_issues = CompiledFunction::detect_issues(*function); if (is_supported && !has_issues) { - CompiledFunction cfun(function, PassParams::ARRAY); + CompiledFunction cfun(*function, PassParams::ARRAY); auto fun = cfun.get_function(); ASSERT_EQUAL(cfun.num_params(), param_values.size()); double result = fun(¶m_values[0]); @@ -135,9 +135,9 @@ TEST("require that large (plugin) set membership checks work") { for(size_t i = 1; i <= 100; ++i) { my_in->add_entry(std::make_unique<nodes::Number>(i)); } - Function my_fun(std::move(my_in), {"a"}); - CompiledFunction cf(my_fun, PassParams::SEPARATE); - CompiledFunction arr_cf(my_fun, PassParams::ARRAY); + auto my_fun = Function::create(std::move(my_in), {"a"}); + CompiledFunction cf(*my_fun, PassParams::SEPARATE); + CompiledFunction arr_cf(*my_fun, PassParams::ARRAY); auto fun = cf.get_function<1>(); auto arr_fun = arr_cf.get_function(); for (double value = 0.5; value <= 100.5; value += 0.5) { @@ -160,7 +160,7 @@ CompiledFunction pass_fun(CompiledFunction cf) { } TEST("require that compiled expression can be passed (moved) around") { - CompiledFunction cf(Function::parse("a+b"), PassParams::SEPARATE); + CompiledFunction cf(*Function::parse("a+b"), PassParams::SEPARATE); auto fun = cf.get_function<2>(); EXPECT_EQUAL(4.0, fun(2.0, 2.0)); CompiledFunction cf2 = pass_fun(std::move(cf)); @@ -171,16 +171,16 @@ TEST("require that compiled expression can be passed (moved) around") { } TEST("require that expressions with constant sub-expressions evaluate correctly") { - CompiledFunction cf(Function::parse("if(1,2,10)+a+b+max(1,2)/1"), PassParams::SEPARATE); + CompiledFunction cf(*Function::parse("if(1,2,10)+a+b+max(1,2)/1"), PassParams::SEPARATE); auto fun = cf.get_function<2>(); EXPECT_EQUAL(7.0, fun(1.0, 2.0)); EXPECT_EQUAL(11.0, fun(3.0, 4.0)); } TEST("dump ir code to verify lazy casting") { - Function function = Function::parse({"a", "b"}, "12==2+if(a==3&&a<10||b,10,5)"); + auto function = Function::parse({"a", "b"}, "12==2+if(a==3&&a<10||b,10,5)"); LLVMWrapper wrapper; - size_t id = wrapper.make_function(function.num_params(), PassParams::SEPARATE, function.root(), {}); + size_t id = wrapper.make_function(function->num_params(), PassParams::SEPARATE, function->root(), {}); wrapper.compile(llvm::dbgs()); // dump module before compiling it using fun_type = double (*)(double, double); fun_type fun = (fun_type) wrapper.get_function_address(id); @@ -189,30 +189,32 @@ TEST("dump ir code to verify lazy casting") { EXPECT_EQUAL(1.0, fun(3.0, 0.0)); } -TEST_MT("require that multithreaded compilation works", 64) { - { - CompiledFunction cf(Function::parse({"x", "y", "z", "w"}, "((x+1)*(y-1))/((z+1)/(w-1))"), - PassParams::SEPARATE); - auto fun = cf.get_function<4>(); - EXPECT_EQUAL(1.0, fun(0.0, 2.0, 0.0, 2.0)); - } - { - CompiledFunction cf(Function::parse({"x", "y", "z", "w"}, "((x+1)*(y-1))/((z+1)/(w-1))"), - PassParams::SEPARATE); - auto fun = cf.get_function<4>(); - EXPECT_EQUAL(4.0, fun(1.0, 3.0, 0.0, 2.0)); - } - { - CompiledFunction cf(Function::parse({"x", "y", "z", "w"}, "((x+1)*(y-1))/((z+1)/(w-1))"), - PassParams::SEPARATE); - auto fun = cf.get_function<4>(); - EXPECT_EQUAL(2.0, fun(1.0, 3.0, 1.0, 2.0)); - } - { - CompiledFunction cf(Function::parse({"x", "y", "z", "w"}, "((x+1)*(y-1))/((z+1)/(w-1))"), - PassParams::SEPARATE); - auto fun = cf.get_function<4>(); - EXPECT_EQUAL(8.0, fun(1.0, 3.0, 1.0, 5.0)); +TEST_MT("require that multithreaded compilation works", 32) { + for (size_t i = 0; i < 16; ++i) { + { + CompiledFunction cf(*Function::parse({"x", "y", "z", "w"}, "((x+1)*(y-1))/((z+1)/(w-1))"), + PassParams::SEPARATE); + auto fun = cf.get_function<4>(); + EXPECT_EQUAL(1.0, fun(0.0, 2.0, 0.0, 2.0)); + } + { + CompiledFunction cf(*Function::parse({"x", "y", "z", "w"}, "((x+1)*(y-1))/((z+1)/(w-1))"), + PassParams::SEPARATE); + auto fun = cf.get_function<4>(); + EXPECT_EQUAL(4.0, fun(1.0, 3.0, 0.0, 2.0)); + } + { + CompiledFunction cf(*Function::parse({"x", "y", "z", "w"}, "((x+1)*(y-1))/((z+1)/(w-1))"), + PassParams::SEPARATE); + auto fun = cf.get_function<4>(); + EXPECT_EQUAL(2.0, fun(1.0, 3.0, 1.0, 2.0)); + } + { + CompiledFunction cf(*Function::parse({"x", "y", "z", "w"}, "((x+1)*(y-1))/((z+1)/(w-1))"), + PassParams::SEPARATE); + auto fun = cf.get_function<4>(); + EXPECT_EQUAL(8.0, fun(1.0, 3.0, 1.0, 5.0)); + } } } @@ -221,12 +223,12 @@ TEST_MT("require that multithreaded compilation works", 64) { TEST("require that function issues can be detected") { auto simple = Function::parse("a+b"); auto complex = Function::parse("join(a,b,f(a,b)(a+b))"); - EXPECT_FALSE(simple.has_error()); - EXPECT_FALSE(complex.has_error()); - EXPECT_FALSE(CompiledFunction::detect_issues(simple)); - EXPECT_TRUE(CompiledFunction::detect_issues(complex)); + EXPECT_FALSE(simple->has_error()); + EXPECT_FALSE(complex->has_error()); + EXPECT_FALSE(CompiledFunction::detect_issues(*simple)); + EXPECT_TRUE(CompiledFunction::detect_issues(*complex)); std::cerr << "Example function issues:" << std::endl - << CompiledFunction::detect_issues(complex).list + << CompiledFunction::detect_issues(*complex).list << std::endl; } diff --git a/eval/src/tests/eval/function/function_test.cpp b/eval/src/tests/eval/function/function_test.cpp index 7cfe54b4b95..fef0ed35668 100644 --- a/eval/src/tests/eval/function/function_test.cpp +++ b/eval/src/tests/eval/function/function_test.cpp @@ -70,59 +70,59 @@ void verify_operator_binding_order(std::initializer_list<OperatorLayer> layers) bool verify_string(const vespalib::string &str, const vespalib::string &expr) { bool ok = true; - ok &= EXPECT_EQUAL(str, as_string(Function::parse(params, expr))); - ok &= EXPECT_EQUAL(expr, Function::parse(params, expr).dump()); + ok &= EXPECT_EQUAL(str, as_string(*Function::parse(params, expr))); + ok &= EXPECT_EQUAL(expr, Function::parse(params, expr)->dump()); return ok; } void verify_error(const vespalib::string &expr, const vespalib::string &expected_error) { - Function function = Function::parse(params, expr); - EXPECT_TRUE(function.has_error()); - EXPECT_EQUAL(expected_error, function.get_error()); + auto function = Function::parse(params, expr); + EXPECT_TRUE(function->has_error()); + EXPECT_EQUAL(expected_error, function->get_error()); } void verify_parse(const vespalib::string &expr, const vespalib::string &expect) { - Function function = Function::parse(expr); - EXPECT_TRUE(!function.has_error()); - EXPECT_EQUAL(function.dump_as_lambda(), expect); + auto function = Function::parse(expr); + EXPECT_TRUE(!function->has_error()); + EXPECT_EQUAL(function->dump_as_lambda(), expect); } TEST("require that scientific numbers can be parsed") { - EXPECT_EQUAL(1.0, as_number(Function::parse(params, "1"))); - EXPECT_EQUAL(2.5, as_number(Function::parse(params, "2.5"))); - EXPECT_EQUAL(100.0, as_number(Function::parse(params, "100"))); - EXPECT_EQUAL(0.01, as_number(Function::parse(params, "0.01"))); - EXPECT_EQUAL(1.05e5, as_number(Function::parse(params, "1.05e5"))); - EXPECT_EQUAL(3e7, as_number(Function::parse(params, "3e7"))); - EXPECT_EQUAL(1.05e5, as_number(Function::parse(params, "1.05e+5"))); - EXPECT_EQUAL(3e7, as_number(Function::parse(params, "3e+7"))); - EXPECT_EQUAL(1.05e-5, as_number(Function::parse(params, "1.05e-5"))); - EXPECT_EQUAL(3e-7, as_number(Function::parse(params, "3e-7"))); - EXPECT_EQUAL(1.05e5, as_number(Function::parse(params, "1.05E5"))); - EXPECT_EQUAL(3e7, as_number(Function::parse(params, "3E7"))); - EXPECT_EQUAL(1.05e5, as_number(Function::parse(params, "1.05E+5"))); - EXPECT_EQUAL(3e7, as_number(Function::parse(params, "3E+7"))); - EXPECT_EQUAL(1.05e-5, as_number(Function::parse(params, "1.05E-5"))); - EXPECT_EQUAL(3e-7, as_number(Function::parse(params, "3E-7"))); + EXPECT_EQUAL(1.0, as_number(*Function::parse(params, "1"))); + EXPECT_EQUAL(2.5, as_number(*Function::parse(params, "2.5"))); + EXPECT_EQUAL(100.0, as_number(*Function::parse(params, "100"))); + EXPECT_EQUAL(0.01, as_number(*Function::parse(params, "0.01"))); + EXPECT_EQUAL(1.05e5, as_number(*Function::parse(params, "1.05e5"))); + EXPECT_EQUAL(3e7, as_number(*Function::parse(params, "3e7"))); + EXPECT_EQUAL(1.05e5, as_number(*Function::parse(params, "1.05e+5"))); + EXPECT_EQUAL(3e7, as_number(*Function::parse(params, "3e+7"))); + EXPECT_EQUAL(1.05e-5, as_number(*Function::parse(params, "1.05e-5"))); + EXPECT_EQUAL(3e-7, as_number(*Function::parse(params, "3e-7"))); + EXPECT_EQUAL(1.05e5, as_number(*Function::parse(params, "1.05E5"))); + EXPECT_EQUAL(3e7, as_number(*Function::parse(params, "3E7"))); + EXPECT_EQUAL(1.05e5, as_number(*Function::parse(params, "1.05E+5"))); + EXPECT_EQUAL(3e7, as_number(*Function::parse(params, "3E+7"))); + EXPECT_EQUAL(1.05e-5, as_number(*Function::parse(params, "1.05E-5"))); + EXPECT_EQUAL(3e-7, as_number(*Function::parse(params, "3E-7"))); } TEST("require that number parsing does not eat +/- operators") { - EXPECT_EQUAL("(((1+2)+3)+4)", Function::parse(params, "1+2+3+4").dump()); - EXPECT_EQUAL("(((1-2)-3)-4)", Function::parse(params, "1-2-3-4").dump()); - EXPECT_EQUAL("(((1+x)+3)+y)", Function::parse(params, "1+x+3+y").dump()); - EXPECT_EQUAL("(((1-x)-3)-y)", Function::parse(params, "1-x-3-y").dump()); + EXPECT_EQUAL("(((1+2)+3)+4)", Function::parse(params, "1+2+3+4")->dump()); + EXPECT_EQUAL("(((1-2)-3)-4)", Function::parse(params, "1-2-3-4")->dump()); + EXPECT_EQUAL("(((1+x)+3)+y)", Function::parse(params, "1+x+3+y")->dump()); + EXPECT_EQUAL("(((1-x)-3)-y)", Function::parse(params, "1-x-3-y")->dump()); } TEST("require that symbols can be parsed") { - EXPECT_EQUAL("x", Function::parse(params, "x").dump()); - EXPECT_EQUAL("y", Function::parse(params, "y").dump()); - EXPECT_EQUAL("z", Function::parse(params, "z").dump()); + EXPECT_EQUAL("x", Function::parse(params, "x")->dump()); + EXPECT_EQUAL("y", Function::parse(params, "y")->dump()); + EXPECT_EQUAL("z", Function::parse(params, "z")->dump()); } TEST("require that parenthesis can be parsed") { - EXPECT_EQUAL("x", Function::parse(params, "(x)").dump()); - EXPECT_EQUAL("x", Function::parse(params, "((x))").dump()); - EXPECT_EQUAL("x", Function::parse(params, "(((x)))").dump()); + EXPECT_EQUAL("x", Function::parse(params, "(x)")->dump()); + EXPECT_EQUAL("x", Function::parse(params, "((x))")->dump()); + EXPECT_EQUAL("x", Function::parse(params, "(((x)))")->dump()); } TEST("require that strings are parsed and dumped correctly") { @@ -139,31 +139,31 @@ TEST("require that strings are parsed and dumped correctly") { vespalib::string raw_expr = vespalib::make_string("\"%c\"", c); vespalib::string hex_expr = vespalib::make_string("\"\\x%02x\"", c); vespalib::string raw_str = vespalib::make_string("%c", c); - EXPECT_EQUAL(raw_str, as_string(Function::parse(params, hex_expr))); + EXPECT_EQUAL(raw_str, as_string(*Function::parse(params, hex_expr))); if (c != 0 && c != '\"' && c != '\\') { - EXPECT_EQUAL(raw_str, as_string(Function::parse(params, raw_expr))); + EXPECT_EQUAL(raw_str, as_string(*Function::parse(params, raw_expr))); } else { - EXPECT_TRUE(Function::parse(params, raw_expr).has_error()); + EXPECT_TRUE(Function::parse(params, raw_expr)->has_error()); } if (c == '\\') { - EXPECT_EQUAL("\"\\\\\"", Function::parse(params, hex_expr).dump()); + EXPECT_EQUAL("\"\\\\\"", Function::parse(params, hex_expr)->dump()); } else if (c == '\"') { - EXPECT_EQUAL("\"\\\"\"", Function::parse(params, hex_expr).dump()); + EXPECT_EQUAL("\"\\\"\"", Function::parse(params, hex_expr)->dump()); } else if (c == '\t') { - EXPECT_EQUAL("\"\\t\"", Function::parse(params, hex_expr).dump()); + EXPECT_EQUAL("\"\\t\"", Function::parse(params, hex_expr)->dump()); } else if (c == '\n') { - EXPECT_EQUAL("\"\\n\"", Function::parse(params, hex_expr).dump()); + EXPECT_EQUAL("\"\\n\"", Function::parse(params, hex_expr)->dump()); } else if (c == '\r') { - EXPECT_EQUAL("\"\\r\"", Function::parse(params, hex_expr).dump()); + EXPECT_EQUAL("\"\\r\"", Function::parse(params, hex_expr)->dump()); } else if (c == '\f') { - EXPECT_EQUAL("\"\\f\"", Function::parse(params, hex_expr).dump()); + EXPECT_EQUAL("\"\\f\"", Function::parse(params, hex_expr)->dump()); } else if ((c >= 32) && (c <= 126)) { if (c >= 'a' && c <= 'z' && c != 't' && c != 'n' && c != 'r' && c != 'f') { - EXPECT_TRUE(Function::parse(params, vespalib::make_string("\"\\%c\"", c)).has_error()); + EXPECT_TRUE(Function::parse(params, vespalib::make_string("\"\\%c\"", c))->has_error()); } - EXPECT_EQUAL(raw_expr, Function::parse(params, hex_expr).dump()); + EXPECT_EQUAL(raw_expr, Function::parse(params, hex_expr)->dump()); } else { - EXPECT_EQUAL(hex_expr, Function::parse(params, hex_expr).dump()); + EXPECT_EQUAL(hex_expr, Function::parse(params, hex_expr)->dump()); } } } @@ -173,36 +173,36 @@ TEST("require that free arrays cannot be parsed") { } TEST("require that negative values can be parsed") { - EXPECT_EQUAL("-1", Function::parse(params, "-1").dump()); - EXPECT_EQUAL("1", Function::parse(params, "--1").dump()); - EXPECT_EQUAL("-1", Function::parse(params, " ( - ( - ( - ( (1) ) ) ) )").dump()); - EXPECT_EQUAL("-2.5", Function::parse(params, "-2.5").dump()); - EXPECT_EQUAL("-100", Function::parse(params, "-100").dump()); + EXPECT_EQUAL("-1", Function::parse(params, "-1")->dump()); + EXPECT_EQUAL("1", Function::parse(params, "--1")->dump()); + EXPECT_EQUAL("-1", Function::parse(params, " ( - ( - ( - ( (1) ) ) ) )")->dump()); + EXPECT_EQUAL("-2.5", Function::parse(params, "-2.5")->dump()); + EXPECT_EQUAL("-100", Function::parse(params, "-100")->dump()); } TEST("require that negative symbols can be parsed") { - EXPECT_EQUAL("(-x)", Function::parse(params, "-x").dump()); - EXPECT_EQUAL("(-y)", Function::parse(params, "-y").dump()); - EXPECT_EQUAL("(-z)", Function::parse(params, "-z").dump()); - EXPECT_EQUAL("(-(-(-x)))", Function::parse(params, "---x").dump()); + EXPECT_EQUAL("(-x)", Function::parse(params, "-x")->dump()); + EXPECT_EQUAL("(-y)", Function::parse(params, "-y")->dump()); + EXPECT_EQUAL("(-z)", Function::parse(params, "-z")->dump()); + EXPECT_EQUAL("(-(-(-x)))", Function::parse(params, "---x")->dump()); } TEST("require that not can be parsed") { - EXPECT_EQUAL("(!x)", Function::parse(params, "!x").dump()); - EXPECT_EQUAL("(!(!x))", Function::parse(params, "!!x").dump()); - EXPECT_EQUAL("(!(!(!x)))", Function::parse(params, "!!!x").dump()); + EXPECT_EQUAL("(!x)", Function::parse(params, "!x")->dump()); + EXPECT_EQUAL("(!(!x))", Function::parse(params, "!!x")->dump()); + EXPECT_EQUAL("(!(!(!x)))", Function::parse(params, "!!!x")->dump()); } TEST("require that not/neg binds to next value") { - EXPECT_EQUAL("((!(!(-(-x))))^z)", Function::parse(params, "!!--x^z").dump()); - EXPECT_EQUAL("((-(-(!(!x))))^z)", Function::parse(params, "--!!x^z").dump()); - EXPECT_EQUAL("((!(-(-(!x))))^z)", Function::parse(params, "!--!x^z").dump()); - EXPECT_EQUAL("((-(!(!(-x))))^z)", Function::parse(params, "-!!-x^z").dump()); + EXPECT_EQUAL("((!(!(-(-x))))^z)", Function::parse(params, "!!--x^z")->dump()); + EXPECT_EQUAL("((-(-(!(!x))))^z)", Function::parse(params, "--!!x^z")->dump()); + EXPECT_EQUAL("((!(-(-(!x))))^z)", Function::parse(params, "!--!x^z")->dump()); + EXPECT_EQUAL("((-(!(!(-x))))^z)", Function::parse(params, "-!!-x^z")->dump()); } TEST("require that parenthesis resolves before not/neg") { - EXPECT_EQUAL("(!(x^z))", Function::parse(params, "!(x^z)").dump()); - EXPECT_EQUAL("(-(x^z))", Function::parse(params, "-(x^z)").dump()); + EXPECT_EQUAL("(!(x^z))", Function::parse(params, "!(x^z)")->dump()); + EXPECT_EQUAL("(-(x^z))", Function::parse(params, "-(x^z)")->dump()); } TEST("require that operators have appropriate binding order") { @@ -216,48 +216,48 @@ TEST("require that operators have appropriate binding order") { TEST("require that operators binding left are calculated left to right") { EXPECT_TRUE(create_op("+")->order() == Operator::Order::LEFT); - EXPECT_EQUAL("((x+y)+z)", Function::parse(params, "x+y+z").dump()); + EXPECT_EQUAL("((x+y)+z)", Function::parse(params, "x+y+z")->dump()); } TEST("require that operators binding right are calculated right to left") { EXPECT_TRUE(create_op("^")->order() == Operator::Order::RIGHT); - EXPECT_EQUAL("(x^(y^z))", Function::parse(params, "x^y^z").dump()); + EXPECT_EQUAL("(x^(y^z))", Function::parse(params, "x^y^z")->dump()); } TEST("require that operators with higher precedence are resolved first") { EXPECT_TRUE(create_op("*")->priority() > create_op("+")->priority()); - EXPECT_EQUAL("(x+(y*z))", Function::parse(params, "x+y*z").dump()); - EXPECT_EQUAL("((x*y)+z)", Function::parse(params, "x*y+z").dump()); + EXPECT_EQUAL("(x+(y*z))", Function::parse(params, "x+y*z")->dump()); + EXPECT_EQUAL("((x*y)+z)", Function::parse(params, "x*y+z")->dump()); } TEST("require that multi-level operator precedence resolving works") { EXPECT_TRUE(create_op("^")->priority() > create_op("*")->priority()); EXPECT_TRUE(create_op("*")->priority() > create_op("+")->priority()); - EXPECT_EQUAL("(x+(y*(z^w)))", Function::parse(params, "x+y*z^w").dump()); - EXPECT_EQUAL("(x+((y^z)*w))", Function::parse(params, "x+y^z*w").dump()); - EXPECT_EQUAL("((x*y)+(z^w))", Function::parse(params, "x*y+z^w").dump()); - EXPECT_EQUAL("((x*(y^z))+w)", Function::parse(params, "x*y^z+w").dump()); - EXPECT_EQUAL("((x^y)+(z*w))", Function::parse(params, "x^y+z*w").dump()); - EXPECT_EQUAL("(((x^y)*z)+w)", Function::parse(params, "x^y*z+w").dump()); + EXPECT_EQUAL("(x+(y*(z^w)))", Function::parse(params, "x+y*z^w")->dump()); + EXPECT_EQUAL("(x+((y^z)*w))", Function::parse(params, "x+y^z*w")->dump()); + EXPECT_EQUAL("((x*y)+(z^w))", Function::parse(params, "x*y+z^w")->dump()); + EXPECT_EQUAL("((x*(y^z))+w)", Function::parse(params, "x*y^z+w")->dump()); + EXPECT_EQUAL("((x^y)+(z*w))", Function::parse(params, "x^y+z*w")->dump()); + EXPECT_EQUAL("(((x^y)*z)+w)", Function::parse(params, "x^y*z+w")->dump()); } TEST("require that expressions are combined when parenthesis are closed") { - EXPECT_EQUAL("((x+(y+z))+w)", Function::parse(params, "x+(y+z)+w").dump()); + EXPECT_EQUAL("((x+(y+z))+w)", Function::parse(params, "x+(y+z)+w")->dump()); } TEST("require that operators can not bind out of parenthesis") { EXPECT_TRUE(create_op("*")->priority() > create_op("+")->priority()); - EXPECT_EQUAL("((x+y)*(x+z))", Function::parse(params, "(x+y)*(x+z)").dump()); + EXPECT_EQUAL("((x+y)*(x+z))", Function::parse(params, "(x+y)*(x+z)")->dump()); } TEST("require that set membership constructs can be parsed") { - EXPECT_EQUAL("(x in [1,2,3])", Function::parse(params, "x in [1,2,3]").dump()); - EXPECT_EQUAL("(x in [1,2,3])", Function::parse(params, "x in [ 1 , 2 , 3 ] ").dump()); - EXPECT_EQUAL("(x in [-1,-2,-3])", Function::parse(params, "x in [-1,-2,-3]").dump()); - EXPECT_EQUAL("(x in [-1,-2,-3])", Function::parse(params, "x in [ - 1 , - 2 , - 3 ]").dump()); - EXPECT_EQUAL("(x in [1,2,3])", Function::parse(params, "x in[1,2,3]").dump()); - EXPECT_EQUAL("(x in [1,2,3])", Function::parse(params, "(x)in[1,2,3]").dump()); - EXPECT_EQUAL("(x in [\"a\",2,\"c\"])", Function::parse(params, "x in [\"a\",2,\"c\"]").dump()); + EXPECT_EQUAL("(x in [1,2,3])", Function::parse(params, "x in [1,2,3]")->dump()); + EXPECT_EQUAL("(x in [1,2,3])", Function::parse(params, "x in [ 1 , 2 , 3 ] ")->dump()); + EXPECT_EQUAL("(x in [-1,-2,-3])", Function::parse(params, "x in [-1,-2,-3]")->dump()); + EXPECT_EQUAL("(x in [-1,-2,-3])", Function::parse(params, "x in [ - 1 , - 2 , - 3 ]")->dump()); + EXPECT_EQUAL("(x in [1,2,3])", Function::parse(params, "x in[1,2,3]")->dump()); + EXPECT_EQUAL("(x in [1,2,3])", Function::parse(params, "(x)in[1,2,3]")->dump()); + EXPECT_EQUAL("(x in [\"a\",2,\"c\"])", Function::parse(params, "x in [\"a\",2,\"c\"]")->dump()); } TEST("require that set membership entries must be array of strings/numbers") { @@ -270,88 +270,88 @@ TEST("require that set membership entries must be array of strings/numbers") { } TEST("require that set membership binds to the next value") { - EXPECT_EQUAL("((x in [1,2,3])^2)", Function::parse(params, "x in [1,2,3]^2").dump()); + EXPECT_EQUAL("((x in [1,2,3])^2)", Function::parse(params, "x in [1,2,3]^2")->dump()); } TEST("require that set membership binds to the left with appropriate precedence") { - EXPECT_EQUAL("((x<y) in [1,2,3])", Function::parse(params, "x < y in [1,2,3]").dump()); - EXPECT_EQUAL("(x&&(y in [1,2,3]))", Function::parse(params, "x && y in [1,2,3]").dump()); + EXPECT_EQUAL("((x<y) in [1,2,3])", Function::parse(params, "x < y in [1,2,3]")->dump()); + EXPECT_EQUAL("(x&&(y in [1,2,3]))", Function::parse(params, "x && y in [1,2,3]")->dump()); } TEST("require that function calls can be parsed") { - EXPECT_EQUAL("min(max(x,y),sqrt(z))", Function::parse(params, "min(max(x,y),sqrt(z))").dump()); + EXPECT_EQUAL("min(max(x,y),sqrt(z))", Function::parse(params, "min(max(x,y),sqrt(z))")->dump()); } TEST("require that if expressions can be parsed") { - EXPECT_EQUAL("if(x,y,z)", Function::parse(params, "if(x,y,z)").dump()); - EXPECT_EQUAL("if(x,y,z)", Function::parse(params, "if (x,y,z)").dump()); - EXPECT_EQUAL("if(x,y,z)", Function::parse(params, " if ( x , y , z ) ").dump()); - EXPECT_EQUAL("if(((x>1)&&(y<3)),(y+1),(z-1))", Function::parse(params, "if(x>1&&y<3,y+1,z-1)").dump()); - EXPECT_EQUAL("if(if(x,y,z),if(x,y,z),if(x,y,z))", Function::parse(params, "if(if(x,y,z),if(x,y,z),if(x,y,z))").dump()); - EXPECT_EQUAL("if(x,y,z,0.25)", Function::parse(params, "if(x,y,z,0.25)").dump()); - EXPECT_EQUAL("if(x,y,z,0.75)", Function::parse(params, "if(x,y,z,0.75)").dump()); + EXPECT_EQUAL("if(x,y,z)", Function::parse(params, "if(x,y,z)")->dump()); + EXPECT_EQUAL("if(x,y,z)", Function::parse(params, "if (x,y,z)")->dump()); + EXPECT_EQUAL("if(x,y,z)", Function::parse(params, " if ( x , y , z ) ")->dump()); + EXPECT_EQUAL("if(((x>1)&&(y<3)),(y+1),(z-1))", Function::parse(params, "if(x>1&&y<3,y+1,z-1)")->dump()); + EXPECT_EQUAL("if(if(x,y,z),if(x,y,z),if(x,y,z))", Function::parse(params, "if(if(x,y,z),if(x,y,z),if(x,y,z))")->dump()); + EXPECT_EQUAL("if(x,y,z,0.25)", Function::parse(params, "if(x,y,z,0.25)")->dump()); + EXPECT_EQUAL("if(x,y,z,0.75)", Function::parse(params, "if(x,y,z,0.75)")->dump()); } TEST("require that if probability can be inspected") { - Function fun_1 = Function::parse("if(x,y,z,0.25)"); - auto if_1 = as<If>(fun_1.root()); + auto fun_1 = Function::parse("if(x,y,z,0.25)"); + auto if_1 = as<If>(fun_1->root()); ASSERT_TRUE(if_1); EXPECT_EQUAL(0.25, if_1->p_true()); - Function fun_2 = Function::parse("if(x,y,z,0.75)"); - auto if_2 = as<If>(fun_2.root()); + auto fun_2 = Function::parse("if(x,y,z,0.75)"); + auto if_2 = as<If>(fun_2->root()); ASSERT_TRUE(if_2); EXPECT_EQUAL(0.75, if_2->p_true()); } TEST("require that symbols can be implicit") { - EXPECT_EQUAL("x", Function::parse("x").dump()); - EXPECT_EQUAL("y", Function::parse("y").dump()); - EXPECT_EQUAL("z", Function::parse("z").dump()); + EXPECT_EQUAL("x", Function::parse("x")->dump()); + EXPECT_EQUAL("y", Function::parse("y")->dump()); + EXPECT_EQUAL("z", Function::parse("z")->dump()); } TEST("require that implicit parameters are picket up left to right") { - Function fun1 = Function::parse("x+y+y"); - Function fun2 = Function::parse("y+y+x"); - EXPECT_EQUAL("((x+y)+y)", fun1.dump()); - EXPECT_EQUAL("((y+y)+x)", fun2.dump()); - ASSERT_EQUAL(2u, fun1.num_params()); - ASSERT_EQUAL(2u, fun2.num_params()); - EXPECT_EQUAL("x", fun1.param_name(0)); - EXPECT_EQUAL("x", fun2.param_name(1)); - EXPECT_EQUAL("y", fun1.param_name(1)); - EXPECT_EQUAL("y", fun2.param_name(0)); + auto fun1 = Function::parse("x+y+y"); + auto fun2 = Function::parse("y+y+x"); + EXPECT_EQUAL("((x+y)+y)", fun1->dump()); + EXPECT_EQUAL("((y+y)+x)", fun2->dump()); + ASSERT_EQUAL(2u, fun1->num_params()); + ASSERT_EQUAL(2u, fun2->num_params()); + EXPECT_EQUAL("x", fun1->param_name(0)); + EXPECT_EQUAL("x", fun2->param_name(1)); + EXPECT_EQUAL("y", fun1->param_name(1)); + EXPECT_EQUAL("y", fun2->param_name(0)); } //----------------------------------------------------------------------------- TEST("require that leaf nodes have no children") { - EXPECT_TRUE(Function::parse("123").root().is_leaf()); - EXPECT_TRUE(Function::parse("x").root().is_leaf()); - EXPECT_TRUE(Function::parse("\"abc\"").root().is_leaf()); - EXPECT_EQUAL(0u, Function::parse("123").root().num_children()); - EXPECT_EQUAL(0u, Function::parse("x").root().num_children()); - EXPECT_EQUAL(0u, Function::parse("\"abc\"").root().num_children()); + EXPECT_TRUE(Function::parse("123")->root().is_leaf()); + EXPECT_TRUE(Function::parse("x")->root().is_leaf()); + EXPECT_TRUE(Function::parse("\"abc\"")->root().is_leaf()); + EXPECT_EQUAL(0u, Function::parse("123")->root().num_children()); + EXPECT_EQUAL(0u, Function::parse("x")->root().num_children()); + EXPECT_EQUAL(0u, Function::parse("\"abc\"")->root().num_children()); } TEST("require that Neg child can be accessed") { - Function f = Function::parse("-x"); - const Node &root = f.root(); + auto f = Function::parse("-x"); + const Node &root = f->root(); EXPECT_TRUE(!root.is_leaf()); ASSERT_EQUAL(1u, root.num_children()); EXPECT_TRUE(root.get_child(0).is_param()); } TEST("require that Not child can be accessed") { - Function f = Function::parse("!1"); - const Node &root = f.root(); + auto f = Function::parse("!1"); + const Node &root = f->root(); EXPECT_TRUE(!root.is_leaf()); ASSERT_EQUAL(1u, root.num_children()); EXPECT_EQUAL(1.0, root.get_child(0).get_const_value()); } TEST("require that If children can be accessed") { - Function f = Function::parse("if(1,2,3)"); - const Node &root = f.root(); + auto f = Function::parse("if(1,2,3)"); + const Node &root = f->root(); EXPECT_TRUE(!root.is_leaf()); ASSERT_EQUAL(3u, root.num_children()); EXPECT_EQUAL(1.0, root.get_child(0).get_const_value()); @@ -360,8 +360,8 @@ TEST("require that If children can be accessed") { } TEST("require that Operator children can be accessed") { - Function f = Function::parse("1+2"); - const Node &root = f.root(); + auto f = Function::parse("1+2"); + const Node &root = f->root(); EXPECT_TRUE(!root.is_leaf()); ASSERT_EQUAL(2u, root.num_children()); EXPECT_EQUAL(1.0, root.get_child(0).get_const_value()); @@ -369,8 +369,8 @@ TEST("require that Operator children can be accessed") { } TEST("require that Call children can be accessed") { - Function f = Function::parse("max(1,2)"); - const Node &root = f.root(); + auto f = Function::parse("max(1,2)"); + const Node &root = f->root(); EXPECT_TRUE(!root.is_leaf()); ASSERT_EQUAL(2u, root.num_children()); EXPECT_EQUAL(1.0, root.get_child(0).get_const_value()); @@ -388,8 +388,8 @@ struct MyNodeHandler : public NodeHandler { size_t detach_from_root(const vespalib::string &expr) { MyNodeHandler handler; - Function function = Function::parse(expr); - nodes::Node &mutable_root = const_cast<nodes::Node&>(function.root()); + auto function = Function::parse(expr); + nodes::Node &mutable_root = const_cast<nodes::Node&>(function->root()); mutable_root.detach_children(handler); return handler.nodes.size(); } @@ -444,15 +444,15 @@ struct MyTraverser : public NodeTraverser { }; size_t verify_traversal(size_t open_true_cnt, const vespalib::string &expression) { - Function function = Function::parse(expression); - if (!EXPECT_TRUE(!function.has_error())) { - fprintf(stderr, "--> %s\n", function.get_error().c_str()); + auto function = Function::parse(expression); + if (!EXPECT_TRUE(!function->has_error())) { + fprintf(stderr, "--> %s\n", function->get_error().c_str()); } MyTraverser traverser(open_true_cnt); - function.root().traverse(traverser); + function->root().traverse(traverser); size_t offset = 0; size_t open_cnt = open_true_cnt; - traverser.verify(function.root(), offset, open_cnt); + traverser.verify(function->root(), offset, open_cnt); EXPECT_EQUAL(offset, traverser.history.size()); return offset; } @@ -476,112 +476,112 @@ TEST("require that traversal works as expected") { //----------------------------------------------------------------------------- TEST("require that node types can be checked") { - EXPECT_TRUE(nodes::check_type<nodes::Add>(Function::parse("1+2").root())); - EXPECT_TRUE(!nodes::check_type<nodes::Add>(Function::parse("1-2").root())); - EXPECT_TRUE(!nodes::check_type<nodes::Add>(Function::parse("1*2").root())); - EXPECT_TRUE(!nodes::check_type<nodes::Add>(Function::parse("1/2").root())); - EXPECT_TRUE((nodes::check_type<nodes::Add, nodes::Sub, nodes::Mul>(Function::parse("1+2").root()))); - EXPECT_TRUE((nodes::check_type<nodes::Add, nodes::Sub, nodes::Mul>(Function::parse("1-2").root()))); - EXPECT_TRUE((nodes::check_type<nodes::Add, nodes::Sub, nodes::Mul>(Function::parse("1*2").root()))); - EXPECT_TRUE((!nodes::check_type<nodes::Add, nodes::Sub, nodes::Mul>(Function::parse("1/2").root()))); + EXPECT_TRUE(nodes::check_type<nodes::Add>(Function::parse("1+2")->root())); + EXPECT_TRUE(!nodes::check_type<nodes::Add>(Function::parse("1-2")->root())); + EXPECT_TRUE(!nodes::check_type<nodes::Add>(Function::parse("1*2")->root())); + EXPECT_TRUE(!nodes::check_type<nodes::Add>(Function::parse("1/2")->root())); + EXPECT_TRUE((nodes::check_type<nodes::Add, nodes::Sub, nodes::Mul>(Function::parse("1+2")->root()))); + EXPECT_TRUE((nodes::check_type<nodes::Add, nodes::Sub, nodes::Mul>(Function::parse("1-2")->root()))); + EXPECT_TRUE((nodes::check_type<nodes::Add, nodes::Sub, nodes::Mul>(Function::parse("1*2")->root()))); + EXPECT_TRUE((!nodes::check_type<nodes::Add, nodes::Sub, nodes::Mul>(Function::parse("1/2")->root()))); } //----------------------------------------------------------------------------- TEST("require that parameter is param, but not const") { - EXPECT_TRUE(Function::parse("x").root().is_param()); - EXPECT_TRUE(!Function::parse("x").root().is_const()); + EXPECT_TRUE(Function::parse("x")->root().is_param()); + EXPECT_TRUE(!Function::parse("x")->root().is_const()); } TEST("require that inverted parameter is not param") { - EXPECT_TRUE(!Function::parse("-x").root().is_param()); + EXPECT_TRUE(!Function::parse("-x")->root().is_param()); } TEST("require that number is const, but not param") { - EXPECT_TRUE(Function::parse("123").root().is_const()); - EXPECT_TRUE(!Function::parse("123").root().is_param()); + EXPECT_TRUE(Function::parse("123")->root().is_const()); + EXPECT_TRUE(!Function::parse("123")->root().is_param()); } TEST("require that string is const") { - EXPECT_TRUE(Function::parse("\"x\"").root().is_const()); + EXPECT_TRUE(Function::parse("\"x\"")->root().is_const()); } TEST("require that neg is const if sub-expression is const") { - EXPECT_TRUE(Function::parse("-123").root().is_const()); - EXPECT_TRUE(!Function::parse("-x").root().is_const()); + EXPECT_TRUE(Function::parse("-123")->root().is_const()); + EXPECT_TRUE(!Function::parse("-x")->root().is_const()); } TEST("require that not is const if sub-expression is const") { - EXPECT_TRUE(Function::parse("!1").root().is_const()); - EXPECT_TRUE(!Function::parse("!x").root().is_const()); + EXPECT_TRUE(Function::parse("!1")->root().is_const()); + EXPECT_TRUE(!Function::parse("!x")->root().is_const()); } TEST("require that operators are cost if both children are const") { - EXPECT_TRUE(!Function::parse("x+y").root().is_const()); - EXPECT_TRUE(!Function::parse("1+y").root().is_const()); - EXPECT_TRUE(!Function::parse("x+2").root().is_const()); - EXPECT_TRUE(Function::parse("1+2").root().is_const()); + EXPECT_TRUE(!Function::parse("x+y")->root().is_const()); + EXPECT_TRUE(!Function::parse("1+y")->root().is_const()); + EXPECT_TRUE(!Function::parse("x+2")->root().is_const()); + EXPECT_TRUE(Function::parse("1+2")->root().is_const()); } TEST("require that set membership is never tagged as const (NB: avoids jit recursion)") { - EXPECT_TRUE(!Function::parse("x in [x,y,z]").root().is_const()); - EXPECT_TRUE(!Function::parse("1 in [x,y,z]").root().is_const()); - EXPECT_TRUE(!Function::parse("1 in [1,y,z]").root().is_const()); - EXPECT_TRUE(!Function::parse("1 in [1,2,3]").root().is_const()); + EXPECT_TRUE(!Function::parse("x in [x,y,z]")->root().is_const()); + EXPECT_TRUE(!Function::parse("1 in [x,y,z]")->root().is_const()); + EXPECT_TRUE(!Function::parse("1 in [1,y,z]")->root().is_const()); + EXPECT_TRUE(!Function::parse("1 in [1,2,3]")->root().is_const()); } TEST("require that calls are cost if all parameters are const") { - EXPECT_TRUE(!Function::parse("max(x,y)").root().is_const()); - EXPECT_TRUE(!Function::parse("max(1,y)").root().is_const()); - EXPECT_TRUE(!Function::parse("max(x,2)").root().is_const()); - EXPECT_TRUE(Function::parse("max(1,2)").root().is_const()); + EXPECT_TRUE(!Function::parse("max(x,y)")->root().is_const()); + EXPECT_TRUE(!Function::parse("max(1,y)")->root().is_const()); + EXPECT_TRUE(!Function::parse("max(x,2)")->root().is_const()); + EXPECT_TRUE(Function::parse("max(1,2)")->root().is_const()); } //----------------------------------------------------------------------------- TEST("require that feature less than constant is tree if children are trees or constants") { - EXPECT_TRUE(Function::parse("if (foo < 2, 3, 4)").root().is_tree()); - EXPECT_TRUE(Function::parse("if (foo < 2, if(bar < 3, 4, 5), 6)").root().is_tree()); - EXPECT_TRUE(Function::parse("if (foo < 2, if(bar < 3, 4, 5), if(baz < 6, 7, 8))").root().is_tree()); - EXPECT_TRUE(Function::parse("if (foo < 2, 3, if(baz < 4, 5, 6))").root().is_tree()); - EXPECT_TRUE(Function::parse("if (foo < max(1,2), 3, 4)").root().is_tree()); - EXPECT_TRUE(!Function::parse("if (2 < foo, 3, 4)").root().is_tree()); - EXPECT_TRUE(!Function::parse("if (foo < bar, 3, 4)").root().is_tree()); - EXPECT_TRUE(!Function::parse("if (1 < 2, 3, 4)").root().is_tree()); - EXPECT_TRUE(!Function::parse("if (foo <= 2, 3, 4)").root().is_tree()); - EXPECT_TRUE(!Function::parse("if (foo == 2, 3, 4)").root().is_tree()); - EXPECT_TRUE(!Function::parse("if (foo > 2, 3, 4)").root().is_tree()); - EXPECT_TRUE(!Function::parse("if (foo >= 2, 3, 4)").root().is_tree()); - EXPECT_TRUE(!Function::parse("if (foo ~= 2, 3, 4)").root().is_tree()); + EXPECT_TRUE(Function::parse("if (foo < 2, 3, 4)")->root().is_tree()); + EXPECT_TRUE(Function::parse("if (foo < 2, if(bar < 3, 4, 5), 6)")->root().is_tree()); + EXPECT_TRUE(Function::parse("if (foo < 2, if(bar < 3, 4, 5), if(baz < 6, 7, 8))")->root().is_tree()); + EXPECT_TRUE(Function::parse("if (foo < 2, 3, if(baz < 4, 5, 6))")->root().is_tree()); + EXPECT_TRUE(Function::parse("if (foo < max(1,2), 3, 4)")->root().is_tree()); + EXPECT_TRUE(!Function::parse("if (2 < foo, 3, 4)")->root().is_tree()); + EXPECT_TRUE(!Function::parse("if (foo < bar, 3, 4)")->root().is_tree()); + EXPECT_TRUE(!Function::parse("if (1 < 2, 3, 4)")->root().is_tree()); + EXPECT_TRUE(!Function::parse("if (foo <= 2, 3, 4)")->root().is_tree()); + EXPECT_TRUE(!Function::parse("if (foo == 2, 3, 4)")->root().is_tree()); + EXPECT_TRUE(!Function::parse("if (foo > 2, 3, 4)")->root().is_tree()); + EXPECT_TRUE(!Function::parse("if (foo >= 2, 3, 4)")->root().is_tree()); + EXPECT_TRUE(!Function::parse("if (foo ~= 2, 3, 4)")->root().is_tree()); } TEST("require that feature in set of constants is tree if children are trees or constants") { - EXPECT_TRUE(Function::parse("if (foo in [1, 2], 3, 4)").root().is_tree()); - EXPECT_TRUE(Function::parse("if (foo in [1, 2], if(bar < 3, 4, 5), 6)").root().is_tree()); - EXPECT_TRUE(Function::parse("if (foo in [1, 2], if(bar < 3, 4, 5), if(baz < 6, 7, 8))").root().is_tree()); - EXPECT_TRUE(Function::parse("if (foo in [1, 2], 3, if(baz < 4, 5, 6))").root().is_tree()); - EXPECT_TRUE(Function::parse("if (foo in [1, 2], min(1,3), max(1,4))").root().is_tree()); - EXPECT_TRUE(!Function::parse("if (1 in [1, 2], 3, 4)").root().is_tree()); + EXPECT_TRUE(Function::parse("if (foo in [1, 2], 3, 4)")->root().is_tree()); + EXPECT_TRUE(Function::parse("if (foo in [1, 2], if(bar < 3, 4, 5), 6)")->root().is_tree()); + EXPECT_TRUE(Function::parse("if (foo in [1, 2], if(bar < 3, 4, 5), if(baz < 6, 7, 8))")->root().is_tree()); + EXPECT_TRUE(Function::parse("if (foo in [1, 2], 3, if(baz < 4, 5, 6))")->root().is_tree()); + EXPECT_TRUE(Function::parse("if (foo in [1, 2], min(1,3), max(1,4))")->root().is_tree()); + EXPECT_TRUE(!Function::parse("if (1 in [1, 2], 3, 4)")->root().is_tree()); } TEST("require that sums of trees and forests are forests") { - EXPECT_TRUE(Function::parse("if(foo<1,2,3) + if(bar<4,5,6)").root().is_forest()); - EXPECT_TRUE(Function::parse("if(foo<1,2,3) + if(bar<4,5,6) + if(bar<7,8,9)").root().is_forest()); - EXPECT_TRUE(!Function::parse("if(foo<1,2,3)").root().is_forest()); - EXPECT_TRUE(!Function::parse("if(foo<1,2,3) + 10").root().is_forest()); - EXPECT_TRUE(!Function::parse("10 + if(bar<4,5,6)").root().is_forest()); - EXPECT_TRUE(!Function::parse("if(foo<1,2,3) - if(bar<4,5,6)").root().is_forest()); - EXPECT_TRUE(!Function::parse("if(foo<1,2,3) * if(bar<4,5,6)").root().is_forest()); - EXPECT_TRUE(!Function::parse("if(foo<1,2,3) / if(bar<4,5,6)").root().is_forest()); - EXPECT_TRUE(!Function::parse("if(foo<1,2,3) ^ if(bar<4,5,6)").root().is_forest()); - EXPECT_TRUE(!Function::parse("if(foo<1,2,3) - if(bar<4,5,6) + if(bar<7,8,9)").root().is_forest()); - EXPECT_TRUE(!Function::parse("if(foo<1,2,3) * if(bar<4,5,6) + if(bar<7,8,9)").root().is_forest()); - EXPECT_TRUE(!Function::parse("if(foo<1,2,3) / if(bar<4,5,6) + if(bar<7,8,9)").root().is_forest()); - EXPECT_TRUE(!Function::parse("if(foo<1,2,3) ^ if(bar<4,5,6) + if(bar<7,8,9)").root().is_forest()); - EXPECT_TRUE(!Function::parse("if(foo<1,2,3) + if(bar<4,5,6) - if(bar<7,8,9)").root().is_forest()); - EXPECT_TRUE(!Function::parse("if(foo<1,2,3) + if(bar<4,5,6) * if(bar<7,8,9)").root().is_forest()); - EXPECT_TRUE(!Function::parse("if(foo<1,2,3) + if(bar<4,5,6) / if(bar<7,8,9)").root().is_forest()); - EXPECT_TRUE(!Function::parse("if(foo<1,2,3) + if(bar<4,5,6) ^ if(bar<7,8,9)").root().is_forest()); + EXPECT_TRUE(Function::parse("if(foo<1,2,3) + if(bar<4,5,6)")->root().is_forest()); + EXPECT_TRUE(Function::parse("if(foo<1,2,3) + if(bar<4,5,6) + if(bar<7,8,9)")->root().is_forest()); + EXPECT_TRUE(!Function::parse("if(foo<1,2,3)")->root().is_forest()); + EXPECT_TRUE(!Function::parse("if(foo<1,2,3) + 10")->root().is_forest()); + EXPECT_TRUE(!Function::parse("10 + if(bar<4,5,6)")->root().is_forest()); + EXPECT_TRUE(!Function::parse("if(foo<1,2,3) - if(bar<4,5,6)")->root().is_forest()); + EXPECT_TRUE(!Function::parse("if(foo<1,2,3) * if(bar<4,5,6)")->root().is_forest()); + EXPECT_TRUE(!Function::parse("if(foo<1,2,3) / if(bar<4,5,6)")->root().is_forest()); + EXPECT_TRUE(!Function::parse("if(foo<1,2,3) ^ if(bar<4,5,6)")->root().is_forest()); + EXPECT_TRUE(!Function::parse("if(foo<1,2,3) - if(bar<4,5,6) + if(bar<7,8,9)")->root().is_forest()); + EXPECT_TRUE(!Function::parse("if(foo<1,2,3) * if(bar<4,5,6) + if(bar<7,8,9)")->root().is_forest()); + EXPECT_TRUE(!Function::parse("if(foo<1,2,3) / if(bar<4,5,6) + if(bar<7,8,9)")->root().is_forest()); + EXPECT_TRUE(!Function::parse("if(foo<1,2,3) ^ if(bar<4,5,6) + if(bar<7,8,9)")->root().is_forest()); + EXPECT_TRUE(!Function::parse("if(foo<1,2,3) + if(bar<4,5,6) - if(bar<7,8,9)")->root().is_forest()); + EXPECT_TRUE(!Function::parse("if(foo<1,2,3) + if(bar<4,5,6) * if(bar<7,8,9)")->root().is_forest()); + EXPECT_TRUE(!Function::parse("if(foo<1,2,3) + if(bar<4,5,6) / if(bar<7,8,9)")->root().is_forest()); + EXPECT_TRUE(!Function::parse("if(foo<1,2,3) + if(bar<4,5,6) ^ if(bar<7,8,9)")->root().is_forest()); } //----------------------------------------------------------------------------- @@ -657,53 +657,53 @@ struct MySymbolExtractor : SymbolExtractor { }; TEST("require that custom symbol extractor may be used") { - EXPECT_EQUAL("[x+]...[missing value]...[*y]", Function::parse(params, "x+*y").dump()); - EXPECT_EQUAL("[x+]...[missing value]...[*y]", Function::parse(params, "x+*y", MySymbolExtractor()).dump()); - EXPECT_EQUAL("[x+]...[unknown symbol: 'x+']...[*y]", Function::parse(params, "x+*y", MySymbolExtractor({'+'})).dump()); - EXPECT_EQUAL("[x+*y]...[unknown symbol: 'x+*y']...[]", Function::parse(params, "x+*y", MySymbolExtractor({'+', '*'})).dump()); + EXPECT_EQUAL("[x+]...[missing value]...[*y]", Function::parse(params, "x+*y")->dump()); + EXPECT_EQUAL("[x+]...[missing value]...[*y]", Function::parse(params, "x+*y", MySymbolExtractor())->dump()); + EXPECT_EQUAL("[x+]...[unknown symbol: 'x+']...[*y]", Function::parse(params, "x+*y", MySymbolExtractor({'+'}))->dump()); + EXPECT_EQUAL("[x+*y]...[unknown symbol: 'x+*y']...[]", Function::parse(params, "x+*y", MySymbolExtractor({'+', '*'}))->dump()); } TEST("require that unknown function works as expected with custom symbol extractor") { - EXPECT_EQUAL("[bogus(]...[unknown function: 'bogus']...[x)+y]", Function::parse(params, "bogus(x)+y").dump()); - EXPECT_EQUAL("[bogus]...[unknown symbol: 'bogus']...[(x)+y]", Function::parse(params, "bogus(x)+y", MySymbolExtractor()).dump()); - EXPECT_EQUAL("[bogus(x)]...[unknown symbol: 'bogus(x)']...[+y]", Function::parse(params, "bogus(x)+y", MySymbolExtractor({'(', ')'})).dump()); + EXPECT_EQUAL("[bogus(]...[unknown function: 'bogus']...[x)+y]", Function::parse(params, "bogus(x)+y")->dump()); + EXPECT_EQUAL("[bogus]...[unknown symbol: 'bogus']...[(x)+y]", Function::parse(params, "bogus(x)+y", MySymbolExtractor())->dump()); + EXPECT_EQUAL("[bogus(x)]...[unknown symbol: 'bogus(x)']...[+y]", Function::parse(params, "bogus(x)+y", MySymbolExtractor({'(', ')'}))->dump()); } TEST("require that unknown function that is valid parameter works as expected with custom symbol extractor") { - EXPECT_EQUAL("[z(]...[unknown function: 'z']...[x)+y]", Function::parse(params, "z(x)+y").dump()); - EXPECT_EQUAL("[z]...[invalid operator: '(']...[(x)+y]", Function::parse(params, "z(x)+y", MySymbolExtractor()).dump()); - EXPECT_EQUAL("[z(x)]...[unknown symbol: 'z(x)']...[+y]", Function::parse(params, "z(x)+y", MySymbolExtractor({'(', ')'})).dump()); + EXPECT_EQUAL("[z(]...[unknown function: 'z']...[x)+y]", Function::parse(params, "z(x)+y")->dump()); + EXPECT_EQUAL("[z]...[invalid operator: '(']...[(x)+y]", Function::parse(params, "z(x)+y", MySymbolExtractor())->dump()); + EXPECT_EQUAL("[z(x)]...[unknown symbol: 'z(x)']...[+y]", Function::parse(params, "z(x)+y", MySymbolExtractor({'(', ')'}))->dump()); } TEST("require that custom symbol extractor is not invoked for known function call") { MySymbolExtractor extractor; EXPECT_EQUAL(extractor.invoke_count, 0u); - EXPECT_EQUAL("[bogus]...[unknown symbol: 'bogus']...[(1,2)]", Function::parse(params, "bogus(1,2)", extractor).dump()); + EXPECT_EQUAL("[bogus]...[unknown symbol: 'bogus']...[(1,2)]", Function::parse(params, "bogus(1,2)", extractor)->dump()); EXPECT_EQUAL(extractor.invoke_count, 1u); - EXPECT_EQUAL("max(1,2)", Function::parse(params, "max(1,2)", extractor).dump()); + EXPECT_EQUAL("max(1,2)", Function::parse(params, "max(1,2)", extractor)->dump()); EXPECT_EQUAL(extractor.invoke_count, 1u); } //----------------------------------------------------------------------------- TEST("require that valid function does not report parse error") { - Function function = Function::parse(params, "x + y"); - EXPECT_TRUE(!function.has_error()); - EXPECT_EQUAL("", function.get_error()); + auto function = Function::parse(params, "x + y"); + EXPECT_TRUE(!function->has_error()); + EXPECT_EQUAL("", function->get_error()); } TEST("require that an invalid function with explicit paramers retain its parameters") { - Function function = Function::parse({"x", "y"}, "x & y"); - EXPECT_TRUE(function.has_error()); - ASSERT_EQUAL(2u, function.num_params()); - ASSERT_EQUAL("x", function.param_name(0)); - ASSERT_EQUAL("y", function.param_name(1)); + auto function = Function::parse({"x", "y"}, "x & y"); + EXPECT_TRUE(function->has_error()); + ASSERT_EQUAL(2u, function->num_params()); + ASSERT_EQUAL("x", function->param_name(0)); + ASSERT_EQUAL("y", function->param_name(1)); } TEST("require that an invalid function with implicit paramers has no parameters") { - Function function = Function::parse("x & y"); - EXPECT_TRUE(function.has_error()); - EXPECT_EQUAL(0u, function.num_params()); + auto function = Function::parse("x & y"); + EXPECT_TRUE(function->has_error()); + EXPECT_EQUAL(0u, function->num_params()); } TEST("require that unknown operator gives parse error") { @@ -725,23 +725,23 @@ TEST("require that missing value gives parse error") { TEST("require that tensor operations can be nested") { EXPECT_EQUAL("reduce(reduce(reduce(a,sum),sum),sum,dim)", - Function::parse("reduce(reduce(reduce(a,sum),sum),sum,dim)").dump()); + Function::parse("reduce(reduce(reduce(a,sum),sum),sum,dim)")->dump()); } //----------------------------------------------------------------------------- TEST("require that tensor map can be parsed") { - EXPECT_EQUAL("map(a,f(x)(x+1))", Function::parse("map(a,f(x)(x+1))").dump()); - EXPECT_EQUAL("map(a,f(x)(x+1))", Function::parse(" map ( a , f ( x ) ( x + 1 ) ) ").dump()); + EXPECT_EQUAL("map(a,f(x)(x+1))", Function::parse("map(a,f(x)(x+1))")->dump()); + EXPECT_EQUAL("map(a,f(x)(x+1))", Function::parse(" map ( a , f ( x ) ( x + 1 ) ) ")->dump()); } TEST("require that tensor join can be parsed") { - EXPECT_EQUAL("join(a,b,f(x,y)(x+y))", Function::parse("join(a,b,f(x,y)(x+y))").dump()); - EXPECT_EQUAL("join(a,b,f(x,y)(x+y))", Function::parse(" join ( a , b , f ( x , y ) ( x + y ) ) ").dump()); + EXPECT_EQUAL("join(a,b,f(x,y)(x+y))", Function::parse("join(a,b,f(x,y)(x+y))")->dump()); + EXPECT_EQUAL("join(a,b,f(x,y)(x+y))", Function::parse(" join ( a , b , f ( x , y ) ( x + y ) ) ")->dump()); } TEST("require that parenthesis are added around lambda expression when needed") { - EXPECT_EQUAL("f(x)(sin(x))", Function::parse("sin(x)").dump_as_lambda()); + EXPECT_EQUAL("f(x)(sin(x))", Function::parse("sin(x)")->dump_as_lambda()); } TEST("require that parse error inside a lambda fails the enclosing expression") { @@ -755,16 +755,16 @@ TEST("require that outer parameters are hidden within a lambda") { //----------------------------------------------------------------------------- TEST("require that tensor reduce can be parsed") { - EXPECT_EQUAL("reduce(x,sum,a,b)", Function::parse({"x"}, "reduce(x,sum,a,b)").dump()); - EXPECT_EQUAL("reduce(x,sum,a,b,c)", Function::parse({"x"}, "reduce(x,sum,a,b,c)").dump()); - EXPECT_EQUAL("reduce(x,sum,a,b,c)", Function::parse({"x"}, " reduce ( x , sum , a , b , c ) ").dump()); - EXPECT_EQUAL("reduce(x,sum)", Function::parse({"x"}, "reduce(x,sum)").dump()); - EXPECT_EQUAL("reduce(x,avg)", Function::parse({"x"}, "reduce(x,avg)").dump()); - EXPECT_EQUAL("reduce(x,avg)", Function::parse({"x"}, "reduce( x , avg )").dump()); - EXPECT_EQUAL("reduce(x,count)", Function::parse({"x"}, "reduce(x,count)").dump()); - EXPECT_EQUAL("reduce(x,prod)", Function::parse({"x"}, "reduce(x,prod)").dump()); - EXPECT_EQUAL("reduce(x,min)", Function::parse({"x"}, "reduce(x,min)").dump()); - EXPECT_EQUAL("reduce(x,max)", Function::parse({"x"}, "reduce(x,max)").dump()); + EXPECT_EQUAL("reduce(x,sum,a,b)", Function::parse({"x"}, "reduce(x,sum,a,b)")->dump()); + EXPECT_EQUAL("reduce(x,sum,a,b,c)", Function::parse({"x"}, "reduce(x,sum,a,b,c)")->dump()); + EXPECT_EQUAL("reduce(x,sum,a,b,c)", Function::parse({"x"}, " reduce ( x , sum , a , b , c ) ")->dump()); + EXPECT_EQUAL("reduce(x,sum)", Function::parse({"x"}, "reduce(x,sum)")->dump()); + EXPECT_EQUAL("reduce(x,avg)", Function::parse({"x"}, "reduce(x,avg)")->dump()); + EXPECT_EQUAL("reduce(x,avg)", Function::parse({"x"}, "reduce( x , avg )")->dump()); + EXPECT_EQUAL("reduce(x,count)", Function::parse({"x"}, "reduce(x,count)")->dump()); + EXPECT_EQUAL("reduce(x,prod)", Function::parse({"x"}, "reduce(x,prod)")->dump()); + EXPECT_EQUAL("reduce(x,min)", Function::parse({"x"}, "reduce(x,min)")->dump()); + EXPECT_EQUAL("reduce(x,max)", Function::parse({"x"}, "reduce(x,max)")->dump()); } TEST("require that tensor reduce with unknown aggregator fails") { @@ -778,14 +778,14 @@ TEST("require that tensor reduce with duplicate dimensions fails") { //----------------------------------------------------------------------------- TEST("require that tensor rename can be parsed") { - EXPECT_EQUAL("rename(x,a,b)", Function::parse({"x"}, "rename(x,a,b)").dump()); - EXPECT_EQUAL("rename(x,a,b)", Function::parse({"x"}, "rename(x,(a),(b))").dump()); - EXPECT_EQUAL("rename(x,a,b)", Function::parse({"x"}, "rename(x,a,(b))").dump()); - EXPECT_EQUAL("rename(x,a,b)", Function::parse({"x"}, "rename(x,(a),b)").dump()); - EXPECT_EQUAL("rename(x,(a,b),(b,a))", Function::parse({"x"}, "rename(x,(a,b),(b,a))").dump()); - EXPECT_EQUAL("rename(x,a,b)", Function::parse({"x"}, "rename( x , a , b )").dump()); - EXPECT_EQUAL("rename(x,a,b)", Function::parse({"x"}, "rename( x , ( a ) , ( b ) )").dump()); - EXPECT_EQUAL("rename(x,(a,b),(b,a))", Function::parse({"x"}, "rename( x , ( a , b ) , ( b , a ) )").dump()); + EXPECT_EQUAL("rename(x,a,b)", Function::parse({"x"}, "rename(x,a,b)")->dump()); + EXPECT_EQUAL("rename(x,a,b)", Function::parse({"x"}, "rename(x,(a),(b))")->dump()); + EXPECT_EQUAL("rename(x,a,b)", Function::parse({"x"}, "rename(x,a,(b))")->dump()); + EXPECT_EQUAL("rename(x,a,b)", Function::parse({"x"}, "rename(x,(a),b)")->dump()); + EXPECT_EQUAL("rename(x,(a,b),(b,a))", Function::parse({"x"}, "rename(x,(a,b),(b,a))")->dump()); + EXPECT_EQUAL("rename(x,a,b)", Function::parse({"x"}, "rename( x , a , b )")->dump()); + EXPECT_EQUAL("rename(x,a,b)", Function::parse({"x"}, "rename( x , ( a ) , ( b ) )")->dump()); + EXPECT_EQUAL("rename(x,(a,b),(b,a))", Function::parse({"x"}, "rename( x , ( a , b ) , ( b , a ) )")->dump()); } TEST("require that tensor rename dimension lists cannot be empty") { @@ -808,9 +808,9 @@ TEST("require that tensor rename dimension lists must have equal size") { //----------------------------------------------------------------------------- TEST("require that tensor lambda can be parsed") { - EXPECT_EQUAL("tensor(x[3]):{{x:0}:0,{x:1}:1,{x:2}:2}", Function::parse({}, "tensor(x[3])(x)").dump()); + EXPECT_EQUAL("tensor(x[3]):{{x:0}:0,{x:1}:1,{x:2}:2}", Function::parse({}, "tensor(x[3])(x)")->dump()); EXPECT_EQUAL("tensor(x[2],y[2]):{{x:0,y:0}:(0==0),{x:0,y:1}:(0==1),{x:1,y:0}:(1==0),{x:1,y:1}:(1==1)}", - Function::parse({}, " tensor ( x [ 2 ] , y [ 2 ] ) ( x == y ) ").dump()); + Function::parse({}, " tensor ( x [ 2 ] , y [ 2 ] ) ( x == y ) ")->dump()); } TEST("require that tensor lambda requires appropriate tensor type") { @@ -821,7 +821,7 @@ TEST("require that tensor lambda requires appropriate tensor type") { TEST("require that tensor lambda can use non-dimension symbols") { EXPECT_EQUAL("tensor(x[2]):{{x:0}:(0==a),{x:1}:(1==a)}", - Function::parse({"a"}, "tensor(x[2])(x==a)").dump()); + Function::parse({"a"}, "tensor(x[2])(x==a)")->dump()); } //----------------------------------------------------------------------------- @@ -832,24 +832,24 @@ TEST("require that verbose tensor create can be parsed") { auto sparse2 = Function::parse("tensor(x{}):{{x:\"a\"}:1,{x:\"b\"}:2,{x:\"c\"}:3}"); auto mixed1 = Function::parse("tensor(x{},y[2]):{{x:a,y:0}:1,{x:a,y:1}:2}"); auto mixed2 = Function::parse("tensor(x{},y[2]):{{x:\"a\",y:0}:1,{x:\"a\",y:1}:2}"); - EXPECT_EQUAL("tensor(x[3]):{{x:0}:1,{x:1}:2,{x:2}:3}", dense.dump()); - EXPECT_EQUAL("tensor(x{}):{{x:\"a\"}:1,{x:\"b\"}:2,{x:\"c\"}:3}", sparse1.dump()); - EXPECT_EQUAL("tensor(x{}):{{x:\"a\"}:1,{x:\"b\"}:2,{x:\"c\"}:3}", sparse2.dump()); - EXPECT_EQUAL("tensor(x{},y[2]):{{x:\"a\",y:0}:1,{x:\"a\",y:1}:2}", mixed1.dump()); - EXPECT_EQUAL("tensor(x{},y[2]):{{x:\"a\",y:0}:1,{x:\"a\",y:1}:2}", mixed2.dump()); + EXPECT_EQUAL("tensor(x[3]):{{x:0}:1,{x:1}:2,{x:2}:3}", dense->dump()); + EXPECT_EQUAL("tensor(x{}):{{x:\"a\"}:1,{x:\"b\"}:2,{x:\"c\"}:3}", sparse1->dump()); + EXPECT_EQUAL("tensor(x{}):{{x:\"a\"}:1,{x:\"b\"}:2,{x:\"c\"}:3}", sparse2->dump()); + EXPECT_EQUAL("tensor(x{},y[2]):{{x:\"a\",y:0}:1,{x:\"a\",y:1}:2}", mixed1->dump()); + EXPECT_EQUAL("tensor(x{},y[2]):{{x:\"a\",y:0}:1,{x:\"a\",y:1}:2}", mixed2->dump()); } TEST("require that verbose tensor create can contain expressions") { auto fun = Function::parse("tensor(x[2]):{{x:0}:1,{x:1}:2+a}"); - EXPECT_EQUAL("tensor(x[2]):{{x:0}:1,{x:1}:(2+a)}", fun.dump()); - ASSERT_EQUAL(fun.num_params(), 1u); - EXPECT_EQUAL(fun.param_name(0), "a"); + EXPECT_EQUAL("tensor(x[2]):{{x:0}:1,{x:1}:(2+a)}", fun->dump()); + ASSERT_EQUAL(fun->num_params(), 1u); + EXPECT_EQUAL(fun->param_name(0), "a"); } TEST("require that verbose tensor create handles spaces and reordering of various elements") { auto fun = Function::parse(" tensor ( y [ 2 ] , x [ 2 ] ) : { { x : 0 , y : 1 } : 2 , " "{ y : 0 , x : 0 } : 1 , { y : 0 , x : 1 } : 3 , { x : 1 , y : 1 } : 4 } "); - EXPECT_EQUAL("tensor(x[2],y[2]):{{x:0,y:0}:1,{x:0,y:1}:2,{x:1,y:0}:3,{x:1,y:1}:4}", fun.dump()); + EXPECT_EQUAL("tensor(x[2],y[2]):{{x:0,y:0}:1,{x:0,y:1}:2,{x:1,y:0}:3,{x:1,y:1}:4}", fun->dump()); } TEST("require that verbose tensor create detects invalid tensor type") { @@ -890,23 +890,23 @@ TEST("require that convenient tensor create can be parsed") { auto sparse2 = Function::parse("tensor(x{}):{\"a\":1,\"b\":2,\"c\":3}"); auto mixed1 = Function::parse("tensor(x{},y[2]):{a:[1,2]}"); auto mixed2 = Function::parse("tensor(x{},y[2]):{\"a\":[1,2]}"); - EXPECT_EQUAL("tensor(x[3]):{{x:0}:1,{x:1}:2,{x:2}:3}", dense.dump()); - EXPECT_EQUAL("tensor(x{}):{{x:\"a\"}:1,{x:\"b\"}:2,{x:\"c\"}:3}", sparse1.dump()); - EXPECT_EQUAL("tensor(x{}):{{x:\"a\"}:1,{x:\"b\"}:2,{x:\"c\"}:3}", sparse2.dump()); - EXPECT_EQUAL("tensor(x{},y[2]):{{x:\"a\",y:0}:1,{x:\"a\",y:1}:2}", mixed1.dump()); - EXPECT_EQUAL("tensor(x{},y[2]):{{x:\"a\",y:0}:1,{x:\"a\",y:1}:2}", mixed2.dump()); + EXPECT_EQUAL("tensor(x[3]):{{x:0}:1,{x:1}:2,{x:2}:3}", dense->dump()); + EXPECT_EQUAL("tensor(x{}):{{x:\"a\"}:1,{x:\"b\"}:2,{x:\"c\"}:3}", sparse1->dump()); + EXPECT_EQUAL("tensor(x{}):{{x:\"a\"}:1,{x:\"b\"}:2,{x:\"c\"}:3}", sparse2->dump()); + EXPECT_EQUAL("tensor(x{},y[2]):{{x:\"a\",y:0}:1,{x:\"a\",y:1}:2}", mixed1->dump()); + EXPECT_EQUAL("tensor(x{},y[2]):{{x:\"a\",y:0}:1,{x:\"a\",y:1}:2}", mixed2->dump()); } TEST("require that convenient tensor create can contain expressions") { auto fun = Function::parse("tensor(x[2]):[1,2+a]"); - EXPECT_EQUAL("tensor(x[2]):{{x:0}:1,{x:1}:(2+a)}", fun.dump()); - ASSERT_EQUAL(fun.num_params(), 1u); - EXPECT_EQUAL(fun.param_name(0), "a"); + EXPECT_EQUAL("tensor(x[2]):{{x:0}:1,{x:1}:(2+a)}", fun->dump()); + ASSERT_EQUAL(fun->num_params(), 1u); + EXPECT_EQUAL(fun->param_name(0), "a"); } TEST("require that convenient tensor create handles dimension order") { auto mixed = Function::parse("tensor(y{},x[2]):{a:[1,2]}"); - EXPECT_EQUAL("tensor(x[2],y{}):{{x:0,y:\"a\"}:1,{x:1,y:\"a\"}:2}", mixed.dump()); + EXPECT_EQUAL("tensor(x[2],y{}):{{x:0,y:\"a\"}:1,{x:1,y:\"a\"}:2}", mixed->dump()); } TEST("require that convenient tensor create can be highly nested") { @@ -914,9 +914,9 @@ TEST("require that convenient tensor create can be highly nested") { auto nested1 = Function::parse("tensor(a{},b{},c[1],d[1]):{x:{y:[[5]]}}"); auto nested2 = Function::parse("tensor(c[1],d[1],a{},b{}):[[{x:{y:5}}]]"); auto nested3 = Function::parse("tensor(a{},c[1],b{},d[1]): { x : [ { y : [ 5 ] } ] } "); - EXPECT_EQUAL(expect, nested1.dump()); - EXPECT_EQUAL(expect, nested2.dump()); - EXPECT_EQUAL(expect, nested3.dump()); + EXPECT_EQUAL(expect, nested1->dump()); + EXPECT_EQUAL(expect, nested2->dump()); + EXPECT_EQUAL(expect, nested3->dump()); } TEST("require that convenient tensor create can have multiple values on multiple levels") { @@ -925,15 +925,15 @@ TEST("require that convenient tensor create can have multiple values on multiple auto fun2 = Function::parse("tensor(y[2],x{}):[{a:1,b:3},{a:2,b:4}]"); auto fun3 = Function::parse("tensor(x{},y[2]): { a : [ 1 , 2 ] , b : [ 3 , 4 ] } "); auto fun4 = Function::parse("tensor(y[2],x{}): [ { a : 1 , b : 3 } , { a : 2 , b : 4 } ] "); - EXPECT_EQUAL(expect, fun1.dump()); - EXPECT_EQUAL(expect, fun2.dump()); - EXPECT_EQUAL(expect, fun3.dump()); - EXPECT_EQUAL(expect, fun4.dump()); + EXPECT_EQUAL(expect, fun1->dump()); + EXPECT_EQUAL(expect, fun2->dump()); + EXPECT_EQUAL(expect, fun3->dump()); + EXPECT_EQUAL(expect, fun4->dump()); } TEST("require that convenient tensor create allows under-specified tensors") { auto fun = Function::parse("tensor(x[2],y[2]):[[],[5]]"); - EXPECT_EQUAL("tensor(x[2],y[2]):{{x:1,y:0}:5}", fun.dump()); + EXPECT_EQUAL("tensor(x[2],y[2]):{{x:1,y:0}:5}", fun->dump()); } TEST("require that convenient tensor create detects invalid tensor type") { @@ -992,16 +992,16 @@ TEST("require that tensor peek empty label is not allowed") { TEST("require that nested tensor lambda using tensor peek can be parsed") { vespalib::string expect("tensor(x[2]):{{x:0}:tensor(y[2]):{{y:0}:((0+0)+a),{y:1}:((0+1)+a)}{y:\"0\"}," "{x:1}:tensor(y[2]):{{y:0}:((1+0)+a),{y:1}:((1+1)+a)}{y:\"1\"}}"); - EXPECT_EQUAL(Function::parse(expect).dump(), expect); + EXPECT_EQUAL(Function::parse(expect)->dump(), expect); auto fun = Function::parse("tensor(x[2])(tensor(y[2])(x+y+a){y:(x)})"); - EXPECT_EQUAL(fun.dump(), expect); + EXPECT_EQUAL(fun->dump(), expect); } //----------------------------------------------------------------------------- TEST("require that tensor concat can be parsed") { - EXPECT_EQUAL("concat(a,b,d)", Function::parse({"a", "b"}, "concat(a,b,d)").dump()); - EXPECT_EQUAL("concat(a,b,d)", Function::parse({"a", "b"}, " concat ( a , b , d ) ").dump()); + EXPECT_EQUAL("concat(a,b,d)", Function::parse({"a", "b"}, "concat(a,b,d)")->dump()); + EXPECT_EQUAL("concat(a,b,d)", Function::parse({"a", "b"}, " concat ( a , b , d ) ")->dump()); } //----------------------------------------------------------------------------- @@ -1012,10 +1012,10 @@ struct CheckExpressions : test::EvalSpec::EvalTest { virtual void next_expression(const std::vector<vespalib::string> ¶m_names, const vespalib::string &expression) override { - Function function = Function::parse(param_names, expression); - if (function.has_error()) { + auto function = Function::parse(param_names, expression); + if (function->has_error()) { failed = true; - fprintf(stderr, "parse error: %s\n", function.get_error().c_str()); + fprintf(stderr, "parse error: %s\n", function->get_error().c_str()); } ++seen_cnt; } diff --git a/eval/src/tests/eval/function_speed/function_speed_test.cpp b/eval/src/tests/eval/function_speed/function_speed_test.cpp index 178ab32d734..1295f482f76 100644 --- a/eval/src/tests/eval/function_speed/function_speed_test.cpp +++ b/eval/src/tests/eval/function_speed/function_speed_test.cpp @@ -39,7 +39,7 @@ double big_gcc_function(double p, double o, double q, double f, double w) { //----------------------------------------------------------------------------- struct Fixture { - Function function; + std::shared_ptr<Function const> function; InterpretedFunction interpreted_simple; InterpretedFunction interpreted; CompiledFunction separate; @@ -47,12 +47,12 @@ struct Fixture { CompiledFunction lazy; Fixture(const vespalib::string &expr) : function(Function::parse(expr)), - interpreted_simple(SimpleTensorEngine::ref(), function, NodeTypes()), - interpreted(DefaultTensorEngine::ref(), function, - NodeTypes(function, std::vector<ValueType>(function.num_params(), ValueType::double_type()))), - separate(function, PassParams::SEPARATE), - array(function, PassParams::ARRAY), - lazy(function, PassParams::LAZY) {} + interpreted_simple(SimpleTensorEngine::ref(), *function, NodeTypes()), + interpreted(DefaultTensorEngine::ref(), *function, + NodeTypes(*function, std::vector<ValueType>(function->num_params(), ValueType::double_type()))), + separate(*function, PassParams::SEPARATE), + array(*function, PassParams::ARRAY), + lazy(*function, PassParams::LAZY) {} }; //----------------------------------------------------------------------------- diff --git a/eval/src/tests/eval/gbdt/fast_forest_bench.cpp b/eval/src/tests/eval/gbdt/fast_forest_bench.cpp index 76a56bec50c..f63fc428f64 100644 --- a/eval/src/tests/eval/gbdt/fast_forest_bench.cpp +++ b/eval/src/tests/eval/gbdt/fast_forest_bench.cpp @@ -31,17 +31,17 @@ void run_fast_forest_bench() { for (size_t invert_percent: std::vector<size_t>({50})) { fprintf(stderr, "\n=== features: %zu, num leafs: %zu, num trees: %zu\n", max_features, tree_size, num_trees); vespalib::string expression = Model().max_features(max_features).less_percent(less_percent).invert_percent(invert_percent).make_forest(num_trees, tree_size); - Function function = Function::parse(expression); + auto function = Function::parse(expression); for (size_t min_bits = std::max(size_t(8), tree_size); true; min_bits *= 2) { - auto forest = FastForest::try_convert(function, min_bits, 64); + auto forest = FastForest::try_convert(*function, min_bits, 64); if (forest) { - estimate_cost(function.num_params(), forest->impl_name().c_str(), *forest); + estimate_cost(function->num_params(), forest->impl_name().c_str(), *forest); } if (min_bits > 64) { break; } } - estimate_cost(function.num_params(), "vm forest", CompiledFunction(function, PassParams::ARRAY, VMForest::optimize_chain)); + estimate_cost(function->num_params(), "vm forest", CompiledFunction(*function, PassParams::ARRAY, VMForest::optimize_chain)); } } } diff --git a/eval/src/tests/eval/gbdt/gbdt_benchmark.cpp b/eval/src/tests/eval/gbdt/gbdt_benchmark.cpp index 20e04c9593e..8c34d6f7d7f 100644 --- a/eval/src/tests/eval/gbdt/gbdt_benchmark.cpp +++ b/eval/src/tests/eval/gbdt/gbdt_benchmark.cpp @@ -163,10 +163,10 @@ std::vector<Option> find_order(const ForestParams ¶ms, size_t num_trees) { std::vector<Result> results; - Function forest = make_forest(params, num_trees); + auto forest = make_forest(params, num_trees); for (size_t i = 0; i < options.size(); ++i) { - CompiledFunction compiled_function = options[i].compile(forest); - CompiledFunction compiled_function_lazy = options[i].compile_lazy(forest); + CompiledFunction compiled_function = options[i].compile(*forest); + CompiledFunction compiled_function_lazy = options[i].compile_lazy(*forest); std::vector<double> inputs(compiled_function.num_params(), 0.5); results.push_back({compiled_function.estimate_cost_us(inputs, budget), i}); double lazy_time = compiled_function_lazy.estimate_cost_us(inputs, budget); @@ -184,7 +184,7 @@ std::vector<Option> find_order(const ForestParams ¶ms, } double expected_path(const ForestParams ¶ms, size_t num_trees) { - return ForestStats(extract_trees(make_forest(params, num_trees).root())).total_expected_path_length; + return ForestStats(extract_trees(make_forest(params, num_trees)->root())).total_expected_path_length; } void explore_segment(const ForestParams ¶ms, diff --git a/eval/src/tests/eval/gbdt/gbdt_test.cpp b/eval/src/tests/eval/gbdt/gbdt_test.cpp index adb3d22847a..e12ea5b9cf3 100644 --- a/eval/src/tests/eval/gbdt/gbdt_test.cpp +++ b/eval/src/tests/eval/gbdt/gbdt_test.cpp @@ -54,10 +54,10 @@ double eval_ff(const FastForest &ff, FastForest::Context &ctx, const std::vector TEST("require that tree stats can be calculated") { for (size_t tree_size = 2; tree_size < 64; ++tree_size) { - EXPECT_EQUAL(tree_size, TreeStats(Function::parse(Model().make_tree(tree_size)).root()).size); + EXPECT_EQUAL(tree_size, TreeStats(Function::parse(Model().make_tree(tree_size))->root()).size); } - TreeStats stats1(Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))").root()); + TreeStats stats1(Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))")->root()); EXPECT_EQUAL(3u, stats1.num_params); EXPECT_EQUAL(4u, stats1.size); EXPECT_EQUAL(1u, stats1.num_less_checks); @@ -65,7 +65,7 @@ TEST("require that tree stats can be calculated") { EXPECT_EQUAL(0u, stats1.num_inverted_checks); EXPECT_EQUAL(3u, stats1.max_set_size); - TreeStats stats2(Function::parse("if((d in [1]),10.0,if(!(e>=1),20.0,30.0))").root()); + TreeStats stats2(Function::parse("if((d in [1]),10.0,if(!(e>=1),20.0,30.0))")->root()); EXPECT_EQUAL(2u, stats2.num_params); EXPECT_EQUAL(3u, stats2.size); EXPECT_EQUAL(0u, stats2.num_less_checks); @@ -78,8 +78,8 @@ TEST("require that trees can be extracted from forest") { for (size_t tree_size = 10; tree_size < 20; ++tree_size) { for (size_t forest_size = 10; forest_size < 20; ++forest_size) { vespalib::string expression = Model().make_forest(forest_size, tree_size); - Function function = Function::parse(expression); - std::vector<const Node *> trees = extract_trees(function.root()); + auto function = Function::parse(expression); + std::vector<const Node *> trees = extract_trees(function->root()); EXPECT_EQUAL(forest_size, trees.size()); for (const Node *tree: trees) { EXPECT_EQUAL(tree_size, TreeStats(*tree).size); @@ -89,10 +89,10 @@ TEST("require that trees can be extracted from forest") { } TEST("require that forest stats can be calculated") { - Function function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))+" - "if((d in [1]),10.0,if(!(e>=1),20.0,30.0))+" - "if((a<1),10.0,if(!(e>=1),20.0,30.0))"); - std::vector<const Node *> trees = extract_trees(function.root()); + auto function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))+" + "if((d in [1]),10.0,if(!(e>=1),20.0,30.0))+" + "if((a<1),10.0,if(!(e>=1),20.0,30.0))"); + std::vector<const Node *> trees = extract_trees(function->root()); ForestStats stats(trees); EXPECT_EQUAL(5u, stats.num_params); EXPECT_EQUAL(3u, stats.num_trees); @@ -109,7 +109,7 @@ TEST("require that forest stats can be calculated") { } double expected_path(const vespalib::string &forest) { - return ForestStats(extract_trees(Function::parse(forest).root())).total_expected_path_length; + return ForestStats(extract_trees(Function::parse(forest)->root())).total_expected_path_length; } TEST("require that expected path length is calculated correctly") { @@ -125,7 +125,7 @@ TEST("require that expected path length is calculated correctly") { } double average_path(const vespalib::string &forest) { - return ForestStats(extract_trees(Function::parse(forest).root())).total_average_path_length; + return ForestStats(extract_trees(Function::parse(forest)->root())).total_average_path_length; } TEST("require that average path length is calculated correctly") { @@ -141,7 +141,7 @@ TEST("require that average path length is calculated correctly") { } double count_tuned(const vespalib::string &forest) { - return ForestStats(extract_trees(Function::parse(forest).root())).total_tuned_checks; + return ForestStats(extract_trees(Function::parse(forest)->root())).total_tuned_checks; } TEST("require that tuned checks are counted correctly") { @@ -204,11 +204,11 @@ struct DummyForest2 : public Forest { TEST("require that trees cannot be optimized by a forest optimizer when using SEPARATE params") { Optimize::Chain chain({DummyForest0::optimize}); - Function function = Function::parse("if((a<1),1.0,if((b<1),if((c<1),2.0,3.0),4.0))+" - "if((d<1),10.0,if((e<1),if((f<1),20.0,30.0),40.0))"); - CompiledFunction compiled_function(function, PassParams::SEPARATE, chain); - CompiledFunction compiled_function_array(function, PassParams::ARRAY, chain); - CompiledFunction compiled_function_lazy(function, PassParams::LAZY, chain); + auto function = Function::parse("if((a<1),1.0,if((b<1),if((c<1),2.0,3.0),4.0))+" + "if((d<1),10.0,if((e<1),if((f<1),20.0,30.0),40.0))"); + CompiledFunction compiled_function(*function, PassParams::SEPARATE, chain); + CompiledFunction compiled_function_array(*function, PassParams::ARRAY, chain); + CompiledFunction compiled_function_lazy(*function, PassParams::LAZY, chain); EXPECT_EQUAL(0u, compiled_function.get_forests().size()); EXPECT_EQUAL(1u, compiled_function_array.get_forests().size()); EXPECT_EQUAL(1u, compiled_function_lazy.get_forests().size()); @@ -226,12 +226,12 @@ TEST("require that trees can be optimized by a forest optimizer when using ARRAY size_t tree_size = 20; for (size_t forest_size = 10; forest_size <= 100; forest_size += 10) { vespalib::string expression = Model().make_forest(forest_size, tree_size); - Function function = Function::parse(expression); - CompiledFunction compiled_function(function, PassParams::ARRAY, chain); - std::vector<double> inputs(function.num_params(), 0.5); + auto function = Function::parse(expression); + CompiledFunction compiled_function(*function, PassParams::ARRAY, chain); + std::vector<double> inputs(function->num_params(), 0.5); if (forest_size < 25) { EXPECT_EQUAL(0u, compiled_function.get_forests().size()); - EXPECT_EQUAL(eval_double(function, inputs), compiled_function.get_function()(&inputs[0])); + EXPECT_EQUAL(eval_double(*function, inputs), compiled_function.get_function()(&inputs[0])); } else if (forest_size < 50) { EXPECT_EQUAL(1u, compiled_function.get_forests().size()); EXPECT_EQUAL(double(forest_size), compiled_function.get_function()(&inputs[0])); @@ -247,12 +247,12 @@ TEST("require that trees can be optimized by a forest optimizer when using LAZY size_t tree_size = 20; for (size_t forest_size = 10; forest_size <= 100; forest_size += 10) { vespalib::string expression = Model().make_forest(forest_size, tree_size); - Function function = Function::parse(expression); - CompiledFunction compiled_function(function, PassParams::LAZY, chain); - std::vector<double> inputs(function.num_params(), 0.5); + auto function = Function::parse(expression); + CompiledFunction compiled_function(*function, PassParams::LAZY, chain); + std::vector<double> inputs(function->num_params(), 0.5); if (forest_size < 25) { EXPECT_EQUAL(0u, compiled_function.get_forests().size()); - EXPECT_EQUAL(eval_double(function, inputs), compiled_function.get_lazy_function()(my_resolve, &inputs[0])); + EXPECT_EQUAL(eval_double(*function, inputs), compiled_function.get_lazy_function()(my_resolve, &inputs[0])); } else if (forest_size < 50) { EXPECT_EQUAL(1u, compiled_function.get_forests().size()); EXPECT_EQUAL(double(forest_size), compiled_function.get_lazy_function()(my_resolve, &inputs[0])); @@ -269,9 +269,9 @@ Optimize::Chain less_only_vm_chain({VMForest::less_only_optimize}); Optimize::Chain general_vm_chain({VMForest::general_optimize}); TEST("require that less only VM tree optimizer works") { - Function function = Function::parse("if((a<1),1.0,if((b<1),if((c<1),2.0,3.0),4.0))+" - "if((d<1),10.0,if((e<1),if((f<1),20.0,30.0),40.0))"); - CompiledFunction compiled_function(function, PassParams::ARRAY, less_only_vm_chain); + auto function = Function::parse("if((a<1),1.0,if((b<1),if((c<1),2.0,3.0),4.0))+" + "if((d<1),10.0,if((e<1),if((f<1),20.0,30.0),40.0))"); + CompiledFunction compiled_function(*function, PassParams::ARRAY, less_only_vm_chain); EXPECT_EQUAL(1u, compiled_function.get_forests().size()); auto f = compiled_function.get_function(); EXPECT_EQUAL(11.0, f(&std::vector<double>({0.5, 0.0, 0.0, 0.5, 0.0, 0.0})[0])); @@ -281,8 +281,8 @@ TEST("require that less only VM tree optimizer works") { } TEST("require that models with in checks are rejected by less only vm optimizer") { - Function function = Function::parse(Model().less_percent(100).make_forest(300, 30)); - auto trees = extract_trees(function.root()); + auto function = Function::parse(Model().less_percent(100).make_forest(300, 30)); + auto trees = extract_trees(function->root()); ForestStats stats(trees); EXPECT_TRUE(Optimize::apply_chain(less_only_vm_chain, stats, trees).valid()); stats.total_in_checks = 1; @@ -290,8 +290,8 @@ TEST("require that models with in checks are rejected by less only vm optimizer" } TEST("require that models with inverted checks are rejected by less only vm optimizer") { - Function function = Function::parse(Model().less_percent(100).make_forest(300, 30)); - auto trees = extract_trees(function.root()); + auto function = Function::parse(Model().less_percent(100).make_forest(300, 30)); + auto trees = extract_trees(function->root()); ForestStats stats(trees); EXPECT_TRUE(Optimize::apply_chain(less_only_vm_chain, stats, trees).valid()); stats.total_inverted_checks = 1; @@ -299,9 +299,9 @@ TEST("require that models with inverted checks are rejected by less only vm opti } TEST("require that general VM tree optimizer works") { - Function function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))+" + auto function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))+" "if((d in [1]),10.0,if(!(e>=1),if((f<1),20.0,30.0),40.0))"); - CompiledFunction compiled_function(function, PassParams::ARRAY, general_vm_chain); + CompiledFunction compiled_function(*function, PassParams::ARRAY, general_vm_chain); EXPECT_EQUAL(1u, compiled_function.get_forests().size()); auto f = compiled_function.get_function(); EXPECT_EQUAL(11.0, f(&std::vector<double>({0.5, 0.0, 0.0, 1.0, 0.0, 0.0})[0])); @@ -311,8 +311,8 @@ TEST("require that general VM tree optimizer works") { } TEST("require that models with too large sets are rejected by general vm optimizer") { - Function function = Function::parse(Model().less_percent(80).make_forest(300, 30)); - auto trees = extract_trees(function.root()); + auto function = Function::parse(Model().less_percent(80).make_forest(300, 30)); + auto trees = extract_trees(function->root()); ForestStats stats(trees); EXPECT_TRUE(stats.total_in_checks > 0); EXPECT_TRUE(Optimize::apply_chain(general_vm_chain, stats, trees).valid()); @@ -321,12 +321,12 @@ TEST("require that models with too large sets are rejected by general vm optimiz } TEST("require that FastForest model evaluation works") { - Function function = Function::parse("if((a<2),1.0,if((b<2),if((c<2),2.0,3.0),4.0))+" + auto function = Function::parse("if((a<2),1.0,if((b<2),if((c<2),2.0,3.0),4.0))+" "if(!(c>=1),10.0,if((a<1),if((b<1),20.0,30.0),40.0))"); - CompiledFunction compiled(function, PassParams::ARRAY, Optimize::none); + CompiledFunction compiled(*function, PassParams::ARRAY, Optimize::none); auto f = compiled.get_function(); EXPECT_TRUE(compiled.get_forests().empty()); - auto forest = FastForest::try_convert(function); + auto forest = FastForest::try_convert(*function); ASSERT_TRUE(forest); auto ctx = forest->create_context(); std::vector<double> p1({0.5, 0.5, 0.5}); // all true: 1.0 + 10.0 @@ -347,21 +347,21 @@ TEST("require that forests evaluate to approximately the same for all evaluation for (size_t less_percent: std::vector<size_t>({100, 80})) { for (size_t invert_percent: std::vector<size_t>({0, 50})) { vespalib::string expression = Model().less_percent(less_percent).invert_percent(invert_percent).make_forest(num_trees, tree_size); - Function function = Function::parse(expression); - auto forest = FastForest::try_convert(function); + auto function = Function::parse(expression); + auto forest = FastForest::try_convert(*function); EXPECT_EQUAL(bool(forest), bool(less_percent == 100)); - CompiledFunction none(function, pass_params, Optimize::none); - CompiledFunction deinline(function, pass_params, DeinlineForest::optimize_chain); - CompiledFunction vm_forest(function, pass_params, VMForest::optimize_chain); + CompiledFunction none(*function, pass_params, Optimize::none); + CompiledFunction deinline(*function, pass_params, DeinlineForest::optimize_chain); + CompiledFunction vm_forest(*function, pass_params, VMForest::optimize_chain); EXPECT_EQUAL(0u, none.get_forests().size()); ASSERT_EQUAL(1u, deinline.get_forests().size()); EXPECT_TRUE(dynamic_cast<DeinlineForest*>(deinline.get_forests()[0].get()) != nullptr); ASSERT_EQUAL(1u, vm_forest.get_forests().size()); EXPECT_TRUE(dynamic_cast<VMForest*>(vm_forest.get_forests()[0].get()) != nullptr); - std::vector<double> inputs(function.num_params(), 0.5); - std::vector<double> inputs_nan(function.num_params(), std::numeric_limits<double>::quiet_NaN()); - double expected = eval_double(function, inputs); - double expected_nan = eval_double(function, inputs_nan); + std::vector<double> inputs(function->num_params(), 0.5); + std::vector<double> inputs_nan(function->num_params(), std::numeric_limits<double>::quiet_NaN()); + double expected = eval_double(*function, inputs); + double expected_nan = eval_double(*function, inputs_nan); EXPECT_EQUAL(expected, eval_compiled(none, inputs)); EXPECT_EQUAL(expected, eval_compiled(deinline, inputs)); EXPECT_EQUAL(expected, eval_compiled(vm_forest, inputs)); @@ -387,15 +387,15 @@ TEST("require that fast forest evaluation is correct for all tree size categorie for (size_t less_percent: std::vector<size_t>({100})) { for (size_t invert_percent: std::vector<size_t>({50})) { vespalib::string expression = Model().max_features(num_features).less_percent(less_percent).invert_percent(invert_percent).make_forest(num_trees, tree_size); - Function function = Function::parse(expression); - auto forest = FastForest::try_convert(function); + auto function = Function::parse(expression); + auto forest = FastForest::try_convert(*function); if ((tree_size <= 64) || is_little_endian()) { ASSERT_TRUE(forest); TEST_STATE(forest->impl_name().c_str()); - std::vector<double> inputs(function.num_params(), 0.5); - std::vector<double> inputs_nan(function.num_params(), std::numeric_limits<double>::quiet_NaN()); - double expected = eval_double(function, inputs); - double expected_nan = eval_double(function, inputs_nan); + std::vector<double> inputs(function->num_params(), 0.5); + std::vector<double> inputs_nan(function->num_params(), std::numeric_limits<double>::quiet_NaN()); + double expected = eval_double(*function, inputs); + double expected_nan = eval_double(*function, inputs_nan); auto ctx = forest->create_context(); EXPECT_EQUAL(expected, eval_ff(*forest, *ctx, inputs)); EXPECT_EQUAL(expected_nan, eval_ff(*forest, *ctx, inputs_nan)); @@ -410,31 +410,31 @@ TEST("require that fast forest evaluation is correct for all tree size categorie //----------------------------------------------------------------------------- TEST("require that GDBT expressions can be detected") { - Function function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))+" - "if((d in [1]),10.0,if(!(e>=1),20.0,30.0))+" - "if((d in [1]),10.0,if(!(e>=1),20.0,30.0))"); - EXPECT_TRUE(contains_gbdt(function.root(), 9)); - EXPECT_TRUE(!contains_gbdt(function.root(), 10)); + auto function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))+" + "if((d in [1]),10.0,if(!(e>=1),20.0,30.0))+" + "if((d in [1]),10.0,if(!(e>=1),20.0,30.0))"); + EXPECT_TRUE(contains_gbdt(function->root(), 9)); + EXPECT_TRUE(!contains_gbdt(function->root(), 10)); } TEST("require that wrapped GDBT expressions can be detected") { - Function function = Function::parse("10*(if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))+" - "if((d in [1]),10.0,if((e<1),20.0,30.0))+" - "if((d in [1]),10.0,if((e<1),20.0,30.0)))"); - EXPECT_TRUE(contains_gbdt(function.root(), 9)); - EXPECT_TRUE(!contains_gbdt(function.root(), 10)); + auto function = Function::parse("10*(if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))+" + "if((d in [1]),10.0,if((e<1),20.0,30.0))+" + "if((d in [1]),10.0,if((e<1),20.0,30.0)))"); + EXPECT_TRUE(contains_gbdt(function->root(), 9)); + EXPECT_TRUE(!contains_gbdt(function->root(), 10)); } TEST("require that lazy parameters are not suggested for GBDT models") { - Function function = Function::parse(Model().make_forest(10, 8)); - EXPECT_TRUE(!CompiledFunction::should_use_lazy_params(function)); + auto function = Function::parse(Model().make_forest(10, 8)); + EXPECT_TRUE(!CompiledFunction::should_use_lazy_params(*function)); } TEST("require that lazy parameters can be suggested for small GBDT models") { - Function function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))+" - "if((d in [1]),10.0,if((e<1),20.0,30.0))+" - "if((d in [1]),10.0,if((e<1),20.0,30.0))"); - EXPECT_TRUE(CompiledFunction::should_use_lazy_params(function)); + auto function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))+" + "if((d in [1]),10.0,if((e<1),20.0,30.0))+" + "if((d in [1]),10.0,if((e<1),20.0,30.0))"); + EXPECT_TRUE(CompiledFunction::should_use_lazy_params(*function)); } //----------------------------------------------------------------------------- diff --git a/eval/src/tests/eval/gbdt/model.cpp b/eval/src/tests/eval/gbdt/model.cpp index 8f0d87a4020..c2507d2f056 100644 --- a/eval/src/tests/eval/gbdt/model.cpp +++ b/eval/src/tests/eval/gbdt/model.cpp @@ -113,7 +113,7 @@ struct ForestParams { //----------------------------------------------------------------------------- -Function make_forest(const ForestParams ¶ms, size_t num_trees) { +auto make_forest(const ForestParams ¶ms, size_t num_trees) { return Function::parse(Model(params.model_seed) .less_percent(params.less_percent) .make_forest(num_trees, params.tree_size)); diff --git a/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp b/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp index c9108ee74ce..d946d244d17 100644 --- a/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp +++ b/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp @@ -28,10 +28,10 @@ struct MyEvalTest : test::EvalSpec::EvalTest { virtual void next_expression(const std::vector<vespalib::string> ¶m_names, const vespalib::string &expression) override { - Function function = Function::parse(param_names, expression); - ASSERT_TRUE(!function.has_error()); + auto function = Function::parse(param_names, expression); + ASSERT_TRUE(!function->has_error()); bool is_supported = true; - bool has_issues = InterpretedFunction::detect_issues(function); + bool has_issues = InterpretedFunction::detect_issues(*function); if (is_supported == has_issues) { const char *supported_str = is_supported ? "supported" : "not supported"; const char *issues_str = has_issues ? "has issues" : "does not have issues"; @@ -46,16 +46,16 @@ struct MyEvalTest : test::EvalSpec::EvalTest { const vespalib::string &expression, double expected_result) override { - Function function = Function::parse(param_names, expression); - ASSERT_TRUE(!function.has_error()); + auto function = Function::parse(param_names, expression); + ASSERT_TRUE(!function->has_error()); bool is_supported = true; - bool has_issues = InterpretedFunction::detect_issues(function); + bool has_issues = InterpretedFunction::detect_issues(*function); if (is_supported && !has_issues) { vespalib::string desc = as_string(param_names, param_values, expression); SimpleParams params(param_values); - verify_result(SimpleTensorEngine::ref(), function, false, "[untyped simple] "+desc, params, expected_result); - verify_result(DefaultTensorEngine::ref(), function, false, "[untyped prod] "+desc, params, expected_result); - verify_result(DefaultTensorEngine::ref(), function, true, "[typed prod] "+desc, params, expected_result); + verify_result(SimpleTensorEngine::ref(), *function, false, "[untyped simple] "+desc, params, expected_result); + verify_result(DefaultTensorEngine::ref(), *function, false, "[untyped prod] "+desc, params, expected_result); + verify_result(DefaultTensorEngine::ref(), *function, true, "[typed prod] "+desc, params, expected_result); } } @@ -102,15 +102,15 @@ TEST_FF("require that compiled evaluation passes all conformance tests", MyEvalT TEST("require that invalid function is tagged with error") { std::vector<vespalib::string> params({"x", "y", "z", "w"}); - Function function = Function::parse(params, "x & y"); - EXPECT_TRUE(function.has_error()); + auto function = Function::parse(params, "x & y"); + EXPECT_TRUE(function->has_error()); } //----------------------------------------------------------------------------- size_t count_ifs(const vespalib::string &expr, std::initializer_list<double> params_in) { - Function fun = Function::parse(expr); - InterpretedFunction ifun(SimpleTensorEngine::ref(), fun, NodeTypes()); + auto fun = Function::parse(expr); + InterpretedFunction ifun(SimpleTensorEngine::ref(), *fun, NodeTypes()); InterpretedFunction::Context ctx(ifun); SimpleParams params(params_in); ifun.eval(ctx, params); @@ -137,8 +137,8 @@ TEST("require that function pointers can be passed as instruction parameters") { } TEST("require that basic addition works") { - Function function = Function::parse("a+10"); - InterpretedFunction interpreted(SimpleTensorEngine::ref(), function, NodeTypes()); + auto function = Function::parse("a+10"); + InterpretedFunction interpreted(SimpleTensorEngine::ref(), *function, NodeTypes()); InterpretedFunction::Context ctx(interpreted); SimpleParams params_20({20}); SimpleParams params_40({40}); @@ -153,20 +153,20 @@ TEST("require that functions with non-compilable lambdas cannot be interpreted") auto good_join = Function::parse("join(a,b,f(x,y)(x+y))"); auto bad_map = Function::parse("map(a,f(x)(map(x,f(i)(i+1))))"); auto bad_join = Function::parse("join(a,b,f(x,y)(join(x,y,f(i,j)(i+j))))"); - for (const Function *good: {&good_map, &good_join}) { + for (const Function *good: {good_map.get(), good_join.get()}) { if (!EXPECT_TRUE(!good->has_error())) { fprintf(stderr, "parse error: %s\n", good->get_error().c_str()); } EXPECT_TRUE(!InterpretedFunction::detect_issues(*good)); } - for (const Function *bad: {&bad_map, &bad_join}) { + for (const Function *bad: {bad_map.get(), bad_join.get()}) { if (!EXPECT_TRUE(!bad->has_error())) { fprintf(stderr, "parse error: %s\n", bad->get_error().c_str()); } EXPECT_TRUE(InterpretedFunction::detect_issues(*bad)); } std::cerr << "Example function issues:" << std::endl - << InterpretedFunction::detect_issues(bad_join).list + << InterpretedFunction::detect_issues(*bad_join).list << std::endl; } diff --git a/eval/src/tests/eval/node_types/node_types_test.cpp b/eval/src/tests/eval/node_types/node_types_test.cpp index 6fbcea8a8a1..dceaf279594 100644 --- a/eval/src/tests/eval/node_types/node_types_test.cpp +++ b/eval/src/tests/eval/node_types/node_types_test.cpp @@ -19,27 +19,27 @@ struct TypeSpecExtractor : public vespalib::eval::SymbolExtractor { }; void verify(const vespalib::string &type_expr, const vespalib::string &type_spec) { - Function function = Function::parse(type_expr, TypeSpecExtractor()); - if (!EXPECT_TRUE(!function.has_error())) { - fprintf(stderr, "parse error: %s\n", function.get_error().c_str()); + auto function = Function::parse(type_expr, TypeSpecExtractor()); + if (!EXPECT_TRUE(!function->has_error())) { + fprintf(stderr, "parse error: %s\n", function->get_error().c_str()); return; } std::vector<ValueType> input_types; - for (size_t i = 0; i < function.num_params(); ++i) { - input_types.push_back(ValueType::from_spec(function.param_name(i))); + for (size_t i = 0; i < function->num_params(); ++i) { + input_types.push_back(ValueType::from_spec(function->param_name(i))); } - NodeTypes types(function, input_types); + NodeTypes types(*function, input_types); ValueType expected_type = ValueType::from_spec(type_spec); - ValueType actual_type = types.get_type(function.root()); + ValueType actual_type = types.get_type(function->root()); EXPECT_EQUAL(expected_type, actual_type); } TEST("require that error nodes have error type") { - Function function = Function::parse("1 2 3 4 5", TypeSpecExtractor()); - EXPECT_TRUE(function.has_error()); - NodeTypes types(function, std::vector<ValueType>()); + auto function = Function::parse("1 2 3 4 5", TypeSpecExtractor()); + EXPECT_TRUE(function->has_error()); + NodeTypes types(*function, std::vector<ValueType>()); ValueType expected_type = ValueType::from_spec("error"); - ValueType actual_type = types.get_type(function.root()); + ValueType actual_type = types.get_type(function->root()); EXPECT_EQUAL(expected_type, actual_type); } @@ -252,21 +252,21 @@ TEST("require that tensor concat resolves correct type") { } TEST("require that double only expressions can be detected") { - Function plain_fun = Function::parse("1+2"); - Function complex_fun = Function::parse("reduce(a,sum)"); - NodeTypes plain_types(plain_fun, {}); - NodeTypes complex_types(complex_fun, {ValueType::tensor_type({{"x"}})}); - EXPECT_TRUE(plain_types.get_type(plain_fun.root()).is_double()); - EXPECT_TRUE(complex_types.get_type(complex_fun.root()).is_double()); + auto plain_fun = Function::parse("1+2"); + auto complex_fun = Function::parse("reduce(a,sum)"); + NodeTypes plain_types(*plain_fun, {}); + NodeTypes complex_types(*complex_fun, {ValueType::tensor_type({{"x"}})}); + EXPECT_TRUE(plain_types.get_type(plain_fun->root()).is_double()); + EXPECT_TRUE(complex_types.get_type(complex_fun->root()).is_double()); EXPECT_TRUE(plain_types.all_types_are_double()); EXPECT_FALSE(complex_types.all_types_are_double()); } TEST("require that empty type repo works as expected") { NodeTypes types; - Function function = Function::parse("1+2"); - EXPECT_FALSE(function.has_error()); - EXPECT_TRUE(types.get_type(function.root()).is_error()); + auto function = Function::parse("1+2"); + EXPECT_FALSE(function->has_error()); + EXPECT_TRUE(types.get_type(function->root()).is_error()); EXPECT_FALSE(types.all_types_are_double()); } diff --git a/eval/src/tests/eval/param_usage/param_usage_test.cpp b/eval/src/tests/eval/param_usage/param_usage_test.cpp index c81be24a9ec..4a8278511b3 100644 --- a/eval/src/tests/eval/param_usage/param_usage_test.cpp +++ b/eval/src/tests/eval/param_usage/param_usage_test.cpp @@ -30,47 +30,47 @@ std::ostream &operator<<(std::ostream &out, const List &list) { TEST("require that simple expression has appropriate parameter usage") { std::vector<vespalib::string> params({"x", "y", "z"}); - Function function = Function::parse(params, "(x+y)*y"); - EXPECT_EQUAL(List(count_param_usage(function)), List({1.0, 2.0, 0.0})); - EXPECT_EQUAL(List(check_param_usage(function)), List({1.0, 1.0, 0.0})); + auto function = Function::parse(params, "(x+y)*y"); + EXPECT_EQUAL(List(count_param_usage(*function)), List({1.0, 2.0, 0.0})); + EXPECT_EQUAL(List(check_param_usage(*function)), List({1.0, 1.0, 0.0})); } TEST("require that if children have 50% probability each by default") { std::vector<vespalib::string> params({"x", "y", "z", "w"}); - Function function = Function::parse(params, "if(w,(x+y)*y,(y+z)*z)"); - EXPECT_EQUAL(List(count_param_usage(function)), List({0.5, 1.5, 1.0, 1.0})); - EXPECT_EQUAL(List(check_param_usage(function)), List({0.5, 1.0, 0.5, 1.0})); + auto function = Function::parse(params, "if(w,(x+y)*y,(y+z)*z)"); + EXPECT_EQUAL(List(count_param_usage(*function)), List({0.5, 1.5, 1.0, 1.0})); + EXPECT_EQUAL(List(check_param_usage(*function)), List({0.5, 1.0, 0.5, 1.0})); } TEST("require that if children probability can be adjusted") { std::vector<vespalib::string> params({"x", "y", "z"}); - Function function = Function::parse(params, "if(z,x*x,y*y,0.8)"); - EXPECT_EQUAL(List(count_param_usage(function)), List({1.6, 0.4, 1.0})); - EXPECT_EQUAL(List(check_param_usage(function)), List({0.8, 0.2, 1.0})); + auto function = Function::parse(params, "if(z,x*x,y*y,0.8)"); + EXPECT_EQUAL(List(count_param_usage(*function)), List({1.6, 0.4, 1.0})); + EXPECT_EQUAL(List(check_param_usage(*function)), List({0.8, 0.2, 1.0})); } TEST("require that chained if statements are combined correctly") { std::vector<vespalib::string> params({"x", "y", "z", "w"}); - Function function = Function::parse(params, "if(z,x,y)+if(w,y,x)"); - EXPECT_EQUAL(List(count_param_usage(function)), List({1.0, 1.0, 1.0, 1.0})); - EXPECT_EQUAL(List(check_param_usage(function)), List({0.75, 0.75, 1.0, 1.0})); + auto function = Function::parse(params, "if(z,x,y)+if(w,y,x)"); + EXPECT_EQUAL(List(count_param_usage(*function)), List({1.0, 1.0, 1.0, 1.0})); + EXPECT_EQUAL(List(check_param_usage(*function)), List({0.75, 0.75, 1.0, 1.0})); } TEST("require that multi-level if statements are combined correctly") { std::vector<vespalib::string> params({"x", "y", "z", "w"}); - Function function = Function::parse(params, "if(z,if(w,y*x,x*x),if(w,y*x,x*x))"); - EXPECT_EQUAL(List(count_param_usage(function)), List({1.5, 0.5, 1.0, 1.0})); - EXPECT_EQUAL(List(check_param_usage(function)), List({1.0, 0.5, 1.0, 1.0})); + auto function = Function::parse(params, "if(z,if(w,y*x,x*x),if(w,y*x,x*x))"); + EXPECT_EQUAL(List(count_param_usage(*function)), List({1.5, 0.5, 1.0, 1.0})); + EXPECT_EQUAL(List(check_param_usage(*function)), List({1.0, 0.5, 1.0, 1.0})); } TEST("require that lazy parameters are suggested for functions with parameters that might not be used") { - Function function = Function::parse("if(z,x,y)+if(w,y,x)"); - EXPECT_TRUE(CompiledFunction::should_use_lazy_params(function)); + auto function = Function::parse("if(z,x,y)+if(w,y,x)"); + EXPECT_TRUE(CompiledFunction::should_use_lazy_params(*function)); } TEST("require that lazy parameters are not suggested for functions where all parameters are always used") { - Function function = Function::parse("a*b*c"); - EXPECT_TRUE(!CompiledFunction::should_use_lazy_params(function)); + auto function = Function::parse("a*b*c"); + EXPECT_TRUE(!CompiledFunction::should_use_lazy_params(*function)); } TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/eval/src/vespa/eval/eval/basic_nodes.h b/eval/src/vespa/eval/eval/basic_nodes.h index 06c296fe753..c1192585f7c 100644 --- a/eval/src/vespa/eval/eval/basic_nodes.h +++ b/eval/src/vespa/eval/eval/basic_nodes.h @@ -61,7 +61,7 @@ struct Node { virtual ~Node() {} }; -typedef std::unique_ptr<Node> Node_UP; +using Node_UP = std::unique_ptr<Node>; /** * Simple typecasting utility. Intended usage: diff --git a/eval/src/vespa/eval/eval/function.cpp b/eval/src/vespa/eval/eval/function.cpp index baaaff25b34..108380f52b7 100644 --- a/eval/src/vespa/eval/eval/function.cpp +++ b/eval/src/vespa/eval/eval/function.cpp @@ -565,7 +565,7 @@ std::vector<vespalib::string> get_idents(ParseContext &ctx) { return list; } -Function parse_lambda(ParseContext &ctx, size_t num_params) { +auto parse_lambda(ParseContext &ctx, size_t num_params) { ctx.skip_spaces(); ctx.eat('f'); auto param_names = get_ident_list(ctx, true); @@ -581,13 +581,13 @@ Function parse_lambda(ParseContext &ctx, size_t num_params) { ctx.fail(make_string("expected lambda with %zu parameter(s), was %zu", num_params, param_names.size())); } - return Function(std::move(lambda_root), std::move(param_names)); + return Function::create(std::move(lambda_root), std::move(param_names)); } void parse_tensor_map(ParseContext &ctx) { Node_UP child = get_expression(ctx); ctx.eat(','); - Function lambda = parse_lambda(ctx, 1); + auto lambda = parse_lambda(ctx, 1); ctx.push_expression(std::make_unique<nodes::TensorMap>(std::move(child), std::move(lambda))); } @@ -596,7 +596,7 @@ void parse_tensor_join(ParseContext &ctx) { ctx.eat(','); Node_UP rhs = get_expression(ctx); ctx.eat(','); - Function lambda = parse_lambda(ctx, 2); + auto lambda = parse_lambda(ctx, 2); ctx.push_expression(std::make_unique<nodes::TensorJoin>(std::move(lhs), std::move(rhs), std::move(lambda))); } @@ -994,15 +994,15 @@ void parse_expression(ParseContext &ctx) { } } -Function parse_function(const Params ¶ms, vespalib::stringref expression, - const SymbolExtractor *symbol_extractor) +auto parse_function(const Params ¶ms, vespalib::stringref expression, + const SymbolExtractor *symbol_extractor) { ParseContext ctx(params, expression.data(), expression.size(), symbol_extractor); parse_expression(ctx); if (ctx.failed() && params.implicit()) { - return Function(ctx.get_result(), std::vector<vespalib::string>()); + return Function::create(ctx.get_result(), std::vector<vespalib::string>()); } - return Function(ctx.get_result(), params.extract()); + return Function::create(ctx.get_result(), params.extract()); } } // namespace vespalib::<unnamed> @@ -1023,25 +1023,31 @@ Function::get_error() const return error ? error->message() : ""; } -Function +std::shared_ptr<Function const> +Function::create(nodes::Node_UP root_in, std::vector<vespalib::string> params_in) +{ + return std::make_shared<Function const>(std::move(root_in), std::move(params_in), ctor_tag()); +} + +std::shared_ptr<Function const> Function::parse(vespalib::stringref expression) { return parse_function(ImplicitParams(), expression, nullptr); } -Function +std::shared_ptr<Function const> Function::parse(vespalib::stringref expression, const SymbolExtractor &symbol_extractor) { return parse_function(ImplicitParams(), expression, &symbol_extractor); } -Function +std::shared_ptr<Function const> Function::parse(const std::vector<vespalib::string> ¶ms, vespalib::stringref expression) { return parse_function(ExplicitParams(params), expression, nullptr); } -Function +std::shared_ptr<Function const> Function::parse(const std::vector<vespalib::string> ¶ms, vespalib::stringref expression, const SymbolExtractor &symbol_extractor) { diff --git a/eval/src/vespa/eval/eval/function.h b/eval/src/vespa/eval/eval/function.h index b3505aad141..b9b9d091961 100644 --- a/eval/src/vespa/eval/eval/function.h +++ b/eval/src/vespa/eval/eval/function.h @@ -34,28 +34,33 @@ struct NodeVisitor; * AST root and the names of all parameters. A function can only be * evaluated using the appropriate number of parameters. **/ -class Function +class Function : public std::enable_shared_from_this<Function> { private: nodes::Node_UP _root; std::vector<vespalib::string> _params; + struct ctor_tag {}; + public: - Function() : _root(new nodes::Number(0.0)), _params() {} - Function(nodes::Node_UP root_in, std::vector<vespalib::string> &¶ms_in) + Function(nodes::Node_UP root_in, std::vector<vespalib::string> params_in, ctor_tag) : _root(std::move(root_in)), _params(std::move(params_in)) {} - Function(Function &&rhs) : _root(std::move(rhs._root)), _params(std::move(rhs._params)) {} + Function(Function &&rhs) = delete; + Function(const Function &rhs) = delete; + Function &operator=(Function &&rhs) = delete; + Function &operator=(const Function &rhs) = delete; ~Function() { delete_node(std::move(_root)); } size_t num_params() const { return _params.size(); } vespalib::stringref param_name(size_t idx) const { return _params[idx]; } bool has_error() const; vespalib::string get_error() const; const nodes::Node &root() const { return *_root; } - static Function parse(vespalib::stringref expression); - static Function parse(vespalib::stringref expression, const SymbolExtractor &symbol_extractor); - static Function parse(const std::vector<vespalib::string> ¶ms, vespalib::stringref expression); - static Function parse(const std::vector<vespalib::string> ¶ms, vespalib::stringref expression, - const SymbolExtractor &symbol_extractor); + static std::shared_ptr<Function const> create(nodes::Node_UP root_in, std::vector<vespalib::string> params_in); + static std::shared_ptr<Function const> parse(vespalib::stringref expression); + static std::shared_ptr<Function const> parse(vespalib::stringref expression, const SymbolExtractor &symbol_extractor); + static std::shared_ptr<Function const> parse(const std::vector<vespalib::string> ¶ms, vespalib::stringref expression); + static std::shared_ptr<Function const> parse(const std::vector<vespalib::string> ¶ms, vespalib::stringref expression, + const SymbolExtractor &symbol_extractor); vespalib::string dump() const { nodes::DumpContext dump_context(_params); return _root->dump(dump_context); diff --git a/eval/src/vespa/eval/eval/llvm/compile_cache.cpp b/eval/src/vespa/eval/eval/llvm/compile_cache.cpp index f4dfad32121..4f93c496f2a 100644 --- a/eval/src/vespa/eval/eval/llvm/compile_cache.cpp +++ b/eval/src/vespa/eval/eval/llvm/compile_cache.cpp @@ -2,13 +2,21 @@ #include "compile_cache.h" #include <vespa/eval/eval/key_gen.h> -#include <thread> namespace vespalib { namespace eval { -std::mutex CompileCache::_lock; -CompileCache::Map CompileCache::_cached; +std::mutex CompileCache::_lock{}; +CompileCache::Map CompileCache::_cached{}; +Executor *CompileCache::_executor{nullptr}; + +const CompiledFunction & +CompileCache::Value::wait_for_result() +{ + std::unique_lock<std::mutex> guard(_lock); + cond.wait(guard, [this](){ return bool(compiled_function); }); + return *compiled_function; +} void CompileCache::release(Map::iterator entry) @@ -22,11 +30,45 @@ CompileCache::release(Map::iterator entry) CompileCache::Token::UP CompileCache::compile(const Function &function, PassParams pass_params) { + Token::UP token; + CompileTask::UP task; + vespalib::string key = gen_key(function, pass_params); + { + std::lock_guard<std::mutex> guard(_lock); + auto pos = _cached.find(key); + if (pos != _cached.end()) { + ++(pos->second.num_refs); + token = std::make_unique<Token>(pos, Token::ctor_tag()); + } else { + auto res = _cached.emplace(std::move(key), Value::ctor_tag()); + assert(res.second); + token = std::make_unique<Token>(res.first, Token::ctor_tag()); + ++(res.first->second.num_refs); + task = std::make_unique<CompileTask>(function, pass_params, + std::make_unique<Token>(res.first, Token::ctor_tag())); + if (_executor != nullptr) { + task = _executor->execute(std::move(task)); + } + } + } + if (task) { + task->run(); + } + return token; +} + +void +CompileCache::attach_executor(Executor &executor) +{ + std::lock_guard<std::mutex> guard(_lock); + _executor = &executor; +} + +void +CompileCache::detach_executor() +{ std::lock_guard<std::mutex> guard(_lock); - CompileContext compile_ctx(function, pass_params); - std::thread thread(do_compile, std::ref(compile_ctx)); - thread.join(); - return std::move(compile_ctx.token); + _executor = nullptr; } size_t @@ -47,18 +89,28 @@ CompileCache::count_refs() return refs; } -void -CompileCache::do_compile(CompileContext &ctx) { - vespalib::string key = gen_key(ctx.function, ctx.pass_params); - auto pos = _cached.find(key); - if (pos != _cached.end()) { - ++(pos->second.num_refs); - ctx.token.reset(new Token(pos)); - } else { - auto res = _cached.emplace(std::move(key), Value(CompiledFunction(ctx.function, ctx.pass_params))); - assert(res.second); - ctx.token.reset(new Token(res.first)); +size_t +CompileCache::count_pending() +{ + std::lock_guard<std::mutex> guard(_lock); + size_t pending = 0; + for (const auto &entry: _cached) { + if (entry.second.compiled_function.get() == nullptr) { + ++pending; + } } + return pending; +} + +void +CompileCache::CompileTask::run() +{ + auto &entry = token->_entry->second; + auto result = std::make_unique<CompiledFunction>(*function, pass_params); + std::lock_guard<std::mutex> guard(_lock); + entry.compiled_function = std::move(result); + entry.cf.store(entry.compiled_function.get(), std::memory_order_release); + entry.cond.notify_all(); } } // namespace vespalib::eval diff --git a/eval/src/vespa/eval/eval/llvm/compile_cache.h b/eval/src/vespa/eval/eval/llvm/compile_cache.h index 3fa721b27bf..0adca808256 100644 --- a/eval/src/vespa/eval/eval/llvm/compile_cache.h +++ b/eval/src/vespa/eval/eval/llvm/compile_cache.h @@ -3,6 +3,8 @@ #pragma once #include "compiled_function.h" +#include <vespa/vespalib/util/executor.h> +#include <condition_variable> #include <mutex> namespace vespalib { @@ -19,15 +21,27 @@ namespace eval { class CompileCache { private: - typedef vespalib::string Key; + using Key = vespalib::string; struct Value { size_t num_refs; - CompiledFunction cf; - Value(CompiledFunction &&cf_in) : num_refs(1), cf(std::move(cf_in)) {} + std::atomic<const CompiledFunction *> cf; + std::condition_variable cond; + CompiledFunction::UP compiled_function; + struct ctor_tag {}; + Value(ctor_tag) : num_refs(1), cf(nullptr), cond(), compiled_function() {} + const CompiledFunction &wait_for_result(); + const CompiledFunction &get() { + const CompiledFunction *ptr = cf.load(std::memory_order_acquire); + if (ptr == nullptr) { + return wait_for_result(); + } + return *ptr; + } }; - typedef std::map<Key,Value> Map; + using Map = std::map<Key,Value>; static std::mutex _lock; static Map _cached; + static Executor *_executor; static void release(Map::iterator entry); @@ -36,33 +50,37 @@ public: { private: friend class CompileCache; - CompileCache::Map::iterator entry; - explicit Token(CompileCache::Map::iterator entry_in) - : entry(entry_in) {} + friend class CompileTask; + struct ctor_tag {}; + CompileCache::Map::iterator _entry; public: - typedef std::unique_ptr<Token> UP; - const CompiledFunction &get() const { return entry->second.cf; } - ~Token() { CompileCache::release(entry); } + Token(Token &&) = delete; + Token(const Token &) = delete; + Token &operator=(Token &&) = delete; + Token &operator=(const Token &) = delete; + using UP = std::unique_ptr<Token>; + explicit Token(CompileCache::Map::iterator entry, ctor_tag) : _entry(entry) {} + const CompiledFunction &get() const { return _entry->second.get(); } + ~Token() { CompileCache::release(_entry); } }; + static Token::UP compile(const Function &function, PassParams pass_params); + static void attach_executor(Executor &executor); + static void detach_executor(); static size_t num_cached(); static size_t count_refs(); + static size_t count_pending(); private: - struct CompileContext { - const Function &function; + struct CompileTask : public Executor::Task { + std::shared_ptr<Function const> function; PassParams pass_params; Token::UP token; - CompileContext(const Function &function_in, - PassParams pass_params_in) - : function(function_in), - pass_params(pass_params_in), - token() {} + CompileTask(const Function &function_in, PassParams pass_params_in, Token::UP token_in) + : function(function_in.shared_from_this()), pass_params(pass_params_in), token(std::move(token_in)) {} + void run() override; }; - - static void do_compile(CompileContext &ctx); }; } // namespace vespalib::eval } // namespace vespalib - diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp index e5b7d45b0e2..72406f09a0d 100644 --- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp +++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp @@ -643,6 +643,7 @@ FunctionBuilder::~FunctionBuilder() { } struct InitializeNativeTarget { InitializeNativeTarget() { + assert(llvm::llvm_is_multithreaded()); llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); llvm::InitializeNativeTargetAsmParser(); @@ -658,8 +659,6 @@ struct InitializeNativeTarget { } } initialize_native_target; -std::recursive_mutex LLVMWrapper::_global_llvm_lock; - LLVMWrapper::LLVMWrapper() : _context(), _module(), @@ -668,17 +667,14 @@ LLVMWrapper::LLVMWrapper() _forests(), _plugin_state() { - std::lock_guard<std::recursive_mutex> guard(_global_llvm_lock); _context = std::make_unique<llvm::LLVMContext>(); - _module = std::make_unique< llvm::Module>("LLVMWrapper", *_context); + _module = std::make_unique<llvm::Module>("LLVMWrapper", *_context); } - size_t LLVMWrapper::make_function(size_t num_params, PassParams pass_params, const Node &root, const gbdt::Optimize::Chain &forest_optimizers) { - std::lock_guard<std::recursive_mutex> guard(_global_llvm_lock); size_t function_id = _functions.size(); FunctionBuilder builder(*_context, *_module, vespalib::make_string("f%zu", function_id), @@ -692,7 +688,6 @@ LLVMWrapper::make_function(size_t num_params, PassParams pass_params, const Node size_t LLVMWrapper::make_forest_fragment(size_t num_params, const std::vector<const Node *> &fragment) { - std::lock_guard<std::recursive_mutex> guard(_global_llvm_lock); size_t function_id = _functions.size(); FunctionBuilder builder(*_context, *_module, vespalib::make_string("f%zu", function_id), @@ -706,7 +701,6 @@ LLVMWrapper::make_forest_fragment(size_t num_params, const std::vector<const Nod void LLVMWrapper::compile(llvm::raw_ostream * dumpStream) { - std::lock_guard<std::recursive_mutex> guard(_global_llvm_lock); if (dumpStream) { _module->print(*dumpStream, nullptr); } @@ -718,12 +712,10 @@ LLVMWrapper::compile(llvm::raw_ostream * dumpStream) void * LLVMWrapper::get_function_address(size_t function_id) { - std::lock_guard<std::recursive_mutex> guard(_global_llvm_lock); return _engine->getPointerToFunction(_functions[function_id]); } LLVMWrapper::~LLVMWrapper() { - std::lock_guard<std::recursive_mutex> guard(_global_llvm_lock); _plugin_state.clear(); _forests.clear(); _functions.clear(); diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.h b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.h index db29771ee9e..040c0bdb73f 100644 --- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.h +++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.h @@ -49,8 +49,6 @@ private: std::vector<gbdt::Forest::UP> _forests; std::vector<PluginState::UP> _plugin_state; - static std::recursive_mutex _global_llvm_lock; - void compile(llvm::raw_ostream * dumpStream); public: LLVMWrapper(); diff --git a/eval/src/vespa/eval/eval/make_tensor_function.cpp b/eval/src/vespa/eval/eval/make_tensor_function.cpp index a7215eabcb9..803047d27c4 100644 --- a/eval/src/vespa/eval/eval/make_tensor_function.cpp +++ b/eval/src/vespa/eval/eval/make_tensor_function.cpp @@ -167,8 +167,8 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser { for (size_t i = 0; i < node.num_entries(); ++i) { my_in->add_entry(std::make_unique<Number>(node.get_entry(i).get_const_value())); } - Function my_fun(std::move(my_in), {"x"}); - const auto &token = stash.create<CompileCache::Token::UP>(CompileCache::compile(my_fun, PassParams::SEPARATE)); + auto my_fun = Function::create(std::move(my_in), {"x"}); + const auto &token = stash.create<CompileCache::Token::UP>(CompileCache::compile(*my_fun, PassParams::SEPARATE)); make_map(node, token.get()->get().get_function<1>()); } void visit(const Neg &node) override { diff --git a/eval/src/vespa/eval/eval/tensor_nodes.h b/eval/src/vespa/eval/eval/tensor_nodes.h index a307faf9b36..4213809f9e3 100644 --- a/eval/src/vespa/eval/eval/tensor_nodes.h +++ b/eval/src/vespa/eval/eval/tensor_nodes.h @@ -17,18 +17,18 @@ namespace nodes { class TensorMap : public Node { private: - Node_UP _child; - Function _lambda; + Node_UP _child; + std::shared_ptr<Function const> _lambda; public: - TensorMap(Node_UP child, Function lambda) + TensorMap(Node_UP child, std::shared_ptr<Function const> lambda) : _child(std::move(child)), _lambda(std::move(lambda)) {} - const Function &lambda() const { return _lambda; } + const Function &lambda() const { return *_lambda; } vespalib::string dump(DumpContext &ctx) const override { vespalib::string str; str += "map("; str += _child->dump(ctx); str += ","; - str += _lambda.dump_as_lambda(); + str += _lambda->dump_as_lambda(); str += ")"; return str; } @@ -46,13 +46,13 @@ public: class TensorJoin : public Node { private: - Node_UP _lhs; - Node_UP _rhs; - Function _lambda; + Node_UP _lhs; + Node_UP _rhs; + std::shared_ptr<Function const> _lambda; public: - TensorJoin(Node_UP lhs, Node_UP rhs, Function lambda) + TensorJoin(Node_UP lhs, Node_UP rhs, std::shared_ptr<Function const> lambda) : _lhs(std::move(lhs)), _rhs(std::move(rhs)), _lambda(std::move(lambda)) {} - const Function &lambda() const { return _lambda; } + const Function &lambda() const { return *_lambda; } vespalib::string dump(DumpContext &ctx) const override { vespalib::string str; str += "join("; @@ -60,7 +60,7 @@ public: str += ","; str += _rhs->dump(ctx); str += ","; - str += _lambda.dump_as_lambda(); + str += _lambda->dump_as_lambda(); str += ")"; return str; } diff --git a/eval/src/vespa/eval/eval/test/eval_fixture.cpp b/eval/src/vespa/eval/eval/test/eval_fixture.cpp index 30e55d0418b..e578b28da18 100644 --- a/eval/src/vespa/eval/eval/test/eval_fixture.cpp +++ b/eval/src/vespa/eval/eval/test/eval_fixture.cpp @@ -91,14 +91,14 @@ EvalFixture::EvalFixture(const TensorEngine &engine, : _engine(engine), _stash(), _function(Function::parse(expr)), - _node_types(get_types(_function, param_repo)), - _mutable_set(get_mutable(_function, param_repo)), - _plain_tensor_function(make_tensor_function(_engine, _function.root(), _node_types, _stash)), + _node_types(get_types(*_function, param_repo)), + _mutable_set(get_mutable(*_function, param_repo)), + _plain_tensor_function(make_tensor_function(_engine, _function->root(), _node_types, _stash)), _patched_tensor_function(maybe_patch(allow_mutable, _plain_tensor_function, _mutable_set, _stash)), _tensor_function(optimized ? _engine.optimize(_patched_tensor_function, _stash) : _patched_tensor_function), _ifun(_engine, _tensor_function), _ictx(_ifun), - _param_values(make_params(_engine, _function, param_repo)), + _param_values(make_params(_engine, *_function, param_repo)), _params(get_refs(_param_values)), _result(_engine.to_spec(_ifun.eval(_ictx, _params))) { diff --git a/eval/src/vespa/eval/eval/test/eval_fixture.h b/eval/src/vespa/eval/eval/test/eval_fixture.h index 9c793c01861..48f6a7e5d2e 100644 --- a/eval/src/vespa/eval/eval/test/eval_fixture.h +++ b/eval/src/vespa/eval/eval/test/eval_fixture.h @@ -41,19 +41,19 @@ public: }; private: - const TensorEngine &_engine; - Stash _stash; - Function _function; - NodeTypes _node_types; - std::set<size_t> _mutable_set; - const TensorFunction &_plain_tensor_function; - const TensorFunction &_patched_tensor_function; - const TensorFunction &_tensor_function; - InterpretedFunction _ifun; - InterpretedFunction::Context _ictx; - std::vector<Value::UP> _param_values; - SimpleObjectParams _params; - TensorSpec _result; + const TensorEngine &_engine; + Stash _stash; + std::shared_ptr<Function const> _function; + NodeTypes _node_types; + std::set<size_t> _mutable_set; + const TensorFunction &_plain_tensor_function; + const TensorFunction &_patched_tensor_function; + const TensorFunction &_tensor_function; + InterpretedFunction _ifun; + InterpretedFunction::Context _ictx; + std::vector<Value::UP> _param_values; + SimpleObjectParams _params; + TensorSpec _result; template <typename T> void find_all(const TensorFunction &node, std::vector<const T *> &list) { diff --git a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp index 3af7ea0a1f3..9242b19310d 100644 --- a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp +++ b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp @@ -119,12 +119,12 @@ struct Expr_V : Eval { const vespalib::string &expr; Expr_V(const vespalib::string &expr_in) : expr(expr_in) {} Result eval(const TensorEngine &engine) const override { - Function fun = Function::parse(expr); - NodeTypes types(fun, {}); - InterpretedFunction ifun(engine, fun, types); + auto fun = Function::parse(expr); + NodeTypes types(*fun, {}); + InterpretedFunction ifun(engine, *fun, types); InterpretedFunction::Context ctx(ifun); SimpleObjectParams params({}); - return Result(engine, check_type(ifun.eval(ctx, params), types.get_type(fun.root()))); + return Result(engine, check_type(ifun.eval(ctx, params), types.get_type(fun->root()))); } }; @@ -133,14 +133,14 @@ struct Expr_T : Eval { const vespalib::string &expr; Expr_T(const vespalib::string &expr_in) : expr(expr_in) {} Result eval(const TensorEngine &engine, const TensorSpec &a) const override { - Function fun = Function::parse(expr); + auto fun = Function::parse(expr); auto a_type = ValueType::from_spec(a.type()); - NodeTypes types(fun, {a_type}); - InterpretedFunction ifun(engine, fun, types); + NodeTypes types(*fun, {a_type}); + InterpretedFunction ifun(engine, *fun, types); InterpretedFunction::Context ctx(ifun); Value::UP va = engine.from_spec(a); SimpleObjectParams params({*va}); - return Result(engine, check_type(ifun.eval(ctx, params), types.get_type(fun.root()))); + return Result(engine, check_type(ifun.eval(ctx, params), types.get_type(fun->root()))); } }; @@ -149,16 +149,16 @@ struct Expr_TT : Eval { vespalib::string expr; Expr_TT(const vespalib::string &expr_in) : expr(expr_in) {} Result eval(const TensorEngine &engine, const TensorSpec &a, const TensorSpec &b) const override { - Function fun = Function::parse(expr); + auto fun = Function::parse(expr); auto a_type = ValueType::from_spec(a.type()); auto b_type = ValueType::from_spec(b.type()); - NodeTypes types(fun, {a_type, b_type}); - InterpretedFunction ifun(engine, fun, types); + NodeTypes types(*fun, {a_type, b_type}); + InterpretedFunction ifun(engine, *fun, types); InterpretedFunction::Context ctx(ifun); Value::UP va = engine.from_spec(a); Value::UP vb = engine.from_spec(b); SimpleObjectParams params({*va,*vb}); - return Result(engine, check_type(ifun.eval(ctx, params), types.get_type(fun.root()))); + return Result(engine, check_type(ifun.eval(ctx, params), types.get_type(fun->root()))); } }; diff --git a/searchlib/src/apps/vespa-ranking-expression-analyzer/vespa-ranking-expression-analyzer.cpp b/searchlib/src/apps/vespa-ranking-expression-analyzer/vespa-ranking-expression-analyzer.cpp index 0f49fb05f35..30177dbe693 100644 --- a/searchlib/src/apps/vespa-ranking-expression-analyzer/vespa-ranking-expression-analyzer.cpp +++ b/searchlib/src/apps/vespa-ranking-expression-analyzer/vespa-ranking-expression-analyzer.cpp @@ -237,9 +237,11 @@ bool vmforest_used(const std::vector<Forest::UP> &forests) { //----------------------------------------------------------------------------- struct State { + using FunPtr = std::shared_ptr<Function const>; + vespalib::string name; vespalib::string expression; - Function function; + FunPtr function; FunctionInfo fun_info; CompiledFunction::UP compiled_function; @@ -256,7 +258,7 @@ struct State { BenchmarkTimer timer(1.0); while (timer.has_budget()) { timer.before(); - CompiledFunction::UP new_cf(new CompiledFunction(function, PassParams::ARRAY)); + CompiledFunction::UP new_cf(new CompiledFunction(*function, PassParams::ARRAY)); timer.after(); compiled_function = std::move(new_cf); } @@ -265,12 +267,12 @@ struct State { void benchmark_option(const vespalib::string &opt_name, Optimize::Chain optimizer_chain) { options.push_back(opt_name); - options_us.push_back(CompiledFunction(function, PassParams::ARRAY, optimizer_chain).estimate_cost_us(fun_info.params)); + options_us.push_back(CompiledFunction(*function, PassParams::ARRAY, optimizer_chain).estimate_cost_us(fun_info.params)); fprintf(stderr, " option '%s' execute time: %g us\n", opt_name.c_str(), options_us.back()); } void maybe_benchmark_fast_forest() { - auto ff = FastForest::try_convert(function); + auto ff = FastForest::try_convert(*function); if (ff) { vespalib::string opt_name("ff"); options.push_back(opt_name); @@ -313,7 +315,7 @@ State::State(const vespalib::string &file_name, vespalib::string expression_in) : name(strip_name(file_name)), expression(std::move(expression_in)), function(Function::parse(expression, FeatureNameExtractor())), - fun_info(function), + fun_info(*function), compiled_function(), llvm_compile_s(0.0), llvm_execute_us(0.0), @@ -355,8 +357,8 @@ MyApp::Main() return 1; } State state(file_name, file.get().make_string()); - if (state.function.has_error()) { - vespalib::string error_message = state.function.get_error(); + if (state.function->has_error()) { + vespalib::string error_message = state.function->get_error(); fprintf(stderr, "input file (%s) contains an illegal expression:\n%s\n", file_name.c_str(), error_message.c_str()); return 1; diff --git a/searchlib/src/tests/features/max_reduce_prod_join_replacer/max_reduce_prod_join_replacer_test.cpp b/searchlib/src/tests/features/max_reduce_prod_join_replacer/max_reduce_prod_join_replacer_test.cpp index c9c8124bb94..1d1c03ab56d 100644 --- a/searchlib/src/tests/features/max_reduce_prod_join_replacer/max_reduce_prod_join_replacer_test.cpp +++ b/searchlib/src/tests/features/max_reduce_prod_join_replacer/max_reduce_prod_join_replacer_test.cpp @@ -46,11 +46,11 @@ struct MyBlueprint : Blueprint { bool replaced(const vespalib::string &expr) { bool was_used = false; ExpressionReplacer::UP replacer = MaxReduceProdJoinReplacer::create(std::make_unique<MyBlueprint>(was_used)); - Function rank_function = Function::parse(expr, FeatureNameExtractor()); - if (!EXPECT_TRUE(!rank_function.has_error())) { - fprintf(stderr, "parse error: %s\n", rank_function.dump().c_str()); + auto rank_function = Function::parse(expr, FeatureNameExtractor()); + if (!EXPECT_TRUE(!rank_function->has_error())) { + fprintf(stderr, "parse error: %s\n", rank_function->dump().c_str()); } - auto result = replacer->maybe_replace(rank_function, IndexEnvironment()); + auto result = replacer->maybe_replace(*rank_function, IndexEnvironment()); EXPECT_EQUAL(bool(result), was_used); return was_used; } diff --git a/searchlib/src/vespa/searchlib/features/element_similarity_feature.cpp b/searchlib/src/vespa/searchlib/features/element_similarity_feature.cpp index a9f31265f49..0676b0a46c4 100644 --- a/searchlib/src/vespa/searchlib/features/element_similarity_feature.cpp +++ b/searchlib/src/vespa/searchlib/features/element_similarity_feature.cpp @@ -416,13 +416,13 @@ ElementSimilarityBlueprint::setup(const fef::IIndexEnvironment &env, const fef:: return false; } std::vector<vespalib::string> args({"p", "o", "q", "f", "w"}); - vespalib::eval::Function function = vespalib::eval::Function::parse(args, expr); - if (function.has_error()) { + auto function = vespalib::eval::Function::parse(args, expr); + if (function->has_error()) { LOG(warning, "'%s': per-element expression parse error: %s", - fnb.buildName().c_str(), function.get_error().c_str()); + fnb.buildName().c_str(), function->get_error().c_str()); return false; } - _outputs.push_back(OutputContext_UP(new OutputContext(function, std::move(aggr)))); + _outputs.push_back(OutputContext_UP(new OutputContext(*function, std::move(aggr)))); } env.hintFieldAccess(field->id()); return true; diff --git a/searchlib/src/vespa/searchlib/features/rankingexpressionfeature.cpp b/searchlib/src/vespa/searchlib/features/rankingexpressionfeature.cpp index 0246f96f1df..62860e04c73 100644 --- a/searchlib/src/vespa/searchlib/features/rankingexpressionfeature.cpp +++ b/searchlib/src/vespa/searchlib/features/rankingexpressionfeature.cpp @@ -256,12 +256,12 @@ RankingExpressionBlueprint::setup(const fef::IIndexEnvironment &env, LOG(error, "No expression given."); return false; } - Function rank_function = Function::parse(script, rankingexpression::FeatureNameExtractor()); - if (rank_function.has_error()) { - LOG(error, "Failed to parse expression '%s': %s", script.c_str(), rank_function.get_error().c_str()); + auto rank_function = Function::parse(script, rankingexpression::FeatureNameExtractor()); + if (rank_function->has_error()) { + LOG(error, "Failed to parse expression '%s': %s", script.c_str(), rank_function->get_error().c_str()); return false; } - _intrinsic_expression = _expression_replacer->maybe_replace(rank_function, env); + _intrinsic_expression = _expression_replacer->maybe_replace(*rank_function, env); if (_intrinsic_expression) { LOG(info, "%s replaced with %s", getName().c_str(), _intrinsic_expression->describe_self().c_str()); describeOutput("out", "result of intrinsic expression", _intrinsic_expression->result_type()); @@ -269,8 +269,8 @@ RankingExpressionBlueprint::setup(const fef::IIndexEnvironment &env, } bool do_compile = true; std::vector<ValueType> input_types; - for (size_t i = 0; i < rank_function.num_params(); ++i) { - const FeatureType &input = defineInput(rank_function.param_name(i), AcceptInput::ANY); + for (size_t i = 0; i < rank_function->num_params(); ++i) { + const FeatureType &input = defineInput(rank_function->param_name(i), AcceptInput::ANY); _input_is_object.push_back(char(input.is_object())); if (input.is_object()) { do_compile = false; @@ -279,17 +279,17 @@ RankingExpressionBlueprint::setup(const fef::IIndexEnvironment &env, input_types.push_back(ValueType::double_type()); } } - NodeTypes node_types(rank_function, input_types); + NodeTypes node_types(*rank_function, input_types); if (!node_types.all_types_are_double()) { do_compile = false; } - ValueType root_type = node_types.get_type(rank_function.root()); + ValueType root_type = node_types.get_type(rank_function->root()); if (root_type.is_error()) { LOG(error, "rank expression contains type errors: %s\n", script.c_str()); return false; } - auto compile_issues = CompiledFunction::detect_issues(rank_function); - auto interpret_issues = InterpretedFunction::detect_issues(rank_function); + auto compile_issues = CompiledFunction::detect_issues(*rank_function); + auto interpret_issues = InterpretedFunction::detect_issues(*rank_function); if (do_compile && compile_issues && !interpret_issues) { LOG(warning, "rank expression compilation disabled: %s\n%s", script.c_str(), list_issues(compile_issues.list).c_str()); @@ -306,18 +306,18 @@ RankingExpressionBlueprint::setup(const fef::IIndexEnvironment &env, if (do_compile) { // fast forest evaluation is a possible replacement for compiled tree models if (fef::indexproperties::eval::UseFastForest::check(env.getProperties())) { - _fast_forest = FastForest::try_convert(rank_function); + _fast_forest = FastForest::try_convert(*rank_function); } if (!_fast_forest) { - bool suggest_lazy = CompiledFunction::should_use_lazy_params(rank_function); + bool suggest_lazy = CompiledFunction::should_use_lazy_params(*rank_function); if (fef::indexproperties::eval::LazyExpressions::check(env.getProperties(), suggest_lazy)) { - _compile_token = CompileCache::compile(rank_function, PassParams::LAZY); + _compile_token = CompileCache::compile(*rank_function, PassParams::LAZY); } else { - _compile_token = CompileCache::compile(rank_function, PassParams::ARRAY); + _compile_token = CompileCache::compile(*rank_function, PassParams::ARRAY); } } } else { - _interpreted_function.reset(new InterpretedFunction(DefaultTensorEngine::ref(), rank_function, node_types)); + _interpreted_function.reset(new InterpretedFunction(DefaultTensorEngine::ref(), *rank_function, node_types)); } } FeatureType output_type = do_compile diff --git a/vespalib/src/vespa/vespalib/util/threadstackexecutorbase.cpp b/vespalib/src/vespa/vespalib/util/threadstackexecutorbase.cpp index 969b5e6f61e..e933477bfb9 100644 --- a/vespalib/src/vespa/vespalib/util/threadstackexecutorbase.cpp +++ b/vespalib/src/vespa/vespalib/util/threadstackexecutorbase.cpp @@ -177,6 +177,13 @@ ThreadStackExecutorBase::internalSetTaskLimit(uint32_t taskLimit) } } +size_t +ThreadStackExecutorBase::num_idle_workers() const +{ + LockGuard lock(_monitor); + return _workers.size(); +} + ThreadStackExecutorBase::Stats ThreadStackExecutorBase::getStats() { diff --git a/vespalib/src/vespa/vespalib/util/threadstackexecutorbase.h b/vespalib/src/vespa/vespalib/util/threadstackexecutorbase.h index 21a6e9cabe0..8718b04d2d3 100644 --- a/vespalib/src/vespa/vespalib/util/threadstackexecutorbase.h +++ b/vespalib/src/vespa/vespalib/util/threadstackexecutorbase.h @@ -198,6 +198,12 @@ protected: public: ThreadStackExecutorBase(const ThreadStackExecutorBase &) = delete; ThreadStackExecutorBase & operator = (const ThreadStackExecutorBase &) = delete; + + /** + * Returns the number of idle workers. This is mostly useful for testing. + **/ + size_t num_idle_workers() const; + /** * Observe and reset stats for this object. * |