summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2019-12-10 10:58:33 +0000
committerHåvard Pettersen <havardpe@oath.com>2019-12-12 12:05:21 +0000
commit970fa07a1d8ae88dafd5ed35bf1f98fae46af0fb (patch)
tree01615c821781f3d87866ad5c9aa0bce46eabd84f /eval
parent2de75cc16ce069fd9b36994101e4dd76de693220 (diff)
async concurrent llvm compilation
Diffstat (limited to 'eval')
-rw-r--r--eval/src/apps/eval_expr/eval_expr.cpp8
-rw-r--r--eval/src/apps/tensor_conformance/tensor_conformance.cpp22
-rw-r--r--eval/src/tests/eval/compile_cache/compile_cache_test.cpp153
-rw-r--r--eval/src/tests/eval/compiled_function/compiled_function_test.cpp94
-rw-r--r--eval/src/tests/eval/function/function_test.cpp566
-rw-r--r--eval/src/tests/eval/function_speed/function_speed_test.cpp14
-rw-r--r--eval/src/tests/eval/gbdt/fast_forest_bench.cpp8
-rw-r--r--eval/src/tests/eval/gbdt/gbdt_benchmark.cpp8
-rw-r--r--eval/src/tests/eval/gbdt/gbdt_test.cpp140
-rw-r--r--eval/src/tests/eval/gbdt/model.cpp2
-rw-r--r--eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp36
-rw-r--r--eval/src/tests/eval/node_types/node_types_test.cpp40
-rw-r--r--eval/src/tests/eval/param_usage/param_usage_test.cpp38
-rw-r--r--eval/src/vespa/eval/eval/basic_nodes.h2
-rw-r--r--eval/src/vespa/eval/eval/function.cpp30
-rw-r--r--eval/src/vespa/eval/eval/function.h23
-rw-r--r--eval/src/vespa/eval/eval/llvm/compile_cache.cpp88
-rw-r--r--eval/src/vespa/eval/eval/llvm/compile_cache.h58
-rw-r--r--eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp12
-rw-r--r--eval/src/vespa/eval/eval/llvm/llvm_wrapper.h2
-rw-r--r--eval/src/vespa/eval/eval/make_tensor_function.cpp4
-rw-r--r--eval/src/vespa/eval/eval/tensor_nodes.h22
-rw-r--r--eval/src/vespa/eval/eval/test/eval_fixture.cpp8
-rw-r--r--eval/src/vespa/eval/eval/test/eval_fixture.h26
-rw-r--r--eval/src/vespa/eval/eval/test/tensor_conformance.cpp24
25 files changed, 785 insertions, 643 deletions
diff --git a/eval/src/apps/eval_expr/eval_expr.cpp b/eval/src/apps/eval_expr/eval_expr.cpp
index afddec40e48..4e6fec926e7 100644
--- a/eval/src/apps/eval_expr/eval_expr.cpp
+++ b/eval/src/apps/eval_expr/eval_expr.cpp
@@ -14,12 +14,12 @@ int main(int argc, char **argv) {
fprintf(stderr, " quote the expression to make it a single parameter\n");
return 1;
}
- Function function = Function::parse({}, argv[1]);
- if (function.has_error()) {
- fprintf(stderr, "expression error: %s\n", function.get_error().c_str());
+ auto function = Function::parse({}, argv[1]);
+ if (function->has_error()) {
+ fprintf(stderr, "expression error: %s\n", function->get_error().c_str());
return 1;
}
- InterpretedFunction interpreted(SimpleTensorEngine::ref(), function, NodeTypes());
+ InterpretedFunction interpreted(SimpleTensorEngine::ref(), *function, NodeTypes());
InterpretedFunction::Context ctx(interpreted);
SimpleParams params({});
const Value &result = interpreted.eval(ctx, params);
diff --git a/eval/src/apps/tensor_conformance/tensor_conformance.cpp b/eval/src/apps/tensor_conformance/tensor_conformance.cpp
index 1a760a436b6..59fe2960dbb 100644
--- a/eval/src/apps/tensor_conformance/tensor_conformance.cpp
+++ b/eval/src/apps/tensor_conformance/tensor_conformance.cpp
@@ -89,20 +89,20 @@ std::vector<ValueType> get_types(const std::vector<Value::UP> &param_values) {
}
TensorSpec eval_expr(const Inspector &test, const TensorEngine &engine, bool typed) {
- Function fun = Function::parse(test["expression"].asString().make_string());
+ auto fun = Function::parse(test["expression"].asString().make_string());
std::vector<Value::UP> param_values;
std::vector<Value::CREF> param_refs;
- for (size_t i = 0; i < fun.num_params(); ++i) {
- param_values.emplace_back(engine.from_spec(extract_value(test["inputs"][fun.param_name(i)])));
+ for (size_t i = 0; i < fun->num_params(); ++i) {
+ param_values.emplace_back(engine.from_spec(extract_value(test["inputs"][fun->param_name(i)])));
param_refs.emplace_back(*param_values.back());
}
- NodeTypes types = typed ? NodeTypes(fun, get_types(param_values)) : NodeTypes();
- InterpretedFunction ifun(engine, fun, types);
+ NodeTypes types = typed ? NodeTypes(*fun, get_types(param_values)) : NodeTypes();
+ InterpretedFunction ifun(engine, *fun, types);
InterpretedFunction::Context ctx(ifun);
SimpleObjectParams params(param_refs);
const Value &result = ifun.eval(ctx, params);
if (typed) {
- ASSERT_EQUAL(result.type(), types.get_type(fun.root()));
+ ASSERT_EQUAL(result.type(), types.get_type(fun->root()));
}
return engine.to_spec(result);
}
@@ -236,12 +236,12 @@ struct TestSpec {
const auto &my_expression = test["expression"];
ASSERT_TRUE(my_expression.valid());
expression = my_expression.asString().make_string();
- Function fun = Function::parse(expression);
- ASSERT_TRUE(!fun.has_error());
- ASSERT_EQUAL(fun.num_params(), test["inputs"].fields());
- for (size_t i = 0; i < fun.num_params(); ++i) {
+ auto fun = Function::parse(expression);
+ ASSERT_TRUE(!fun->has_error());
+ ASSERT_EQUAL(fun->num_params(), test["inputs"].fields());
+ for (size_t i = 0; i < fun->num_params(); ++i) {
TEST_STATE(make_string("input #%zu", i).c_str());
- const auto &my_input = test["inputs"][fun.param_name(i)];
+ const auto &my_input = test["inputs"][fun->param_name(i)];
ASSERT_TRUE(my_input.valid());
inputs.push_back(extract_value(my_input));
}
diff --git a/eval/src/tests/eval/compile_cache/compile_cache_test.cpp b/eval/src/tests/eval/compile_cache/compile_cache_test.cpp
index 6796d498719..01a29b9b9e8 100644
--- a/eval/src/tests/eval/compile_cache/compile_cache_test.cpp
+++ b/eval/src/tests/eval/compile_cache/compile_cache_test.cpp
@@ -3,57 +3,54 @@
#include <vespa/eval/eval/llvm/compile_cache.h>
#include <vespa/eval/eval/key_gen.h>
#include <vespa/eval/eval/test/eval_spec.h>
+#include <vespa/vespalib/util/time.h>
+#include <vespa/vespalib/util/threadstackexecutor.h>
+#include <thread>
#include <set>
+using namespace vespalib;
using namespace vespalib::eval;
//-----------------------------------------------------------------------------
TEST("require that parameter passing selection affects function key") {
- EXPECT_NOT_EQUAL(gen_key(Function::parse("a+b"), PassParams::SEPARATE),
- gen_key(Function::parse("a+b"), PassParams::ARRAY));
+ EXPECT_NOT_EQUAL(gen_key(*Function::parse("a+b"), PassParams::SEPARATE),
+ gen_key(*Function::parse("a+b"), PassParams::ARRAY));
}
TEST("require that the number of parameters affects function key") {
- EXPECT_NOT_EQUAL(gen_key(Function::parse({"a", "b"}, "a+b"), PassParams::SEPARATE),
- gen_key(Function::parse({"a", "b", "c"}, "a+b"), PassParams::SEPARATE));
- EXPECT_NOT_EQUAL(gen_key(Function::parse({"a", "b"}, "a+b"), PassParams::ARRAY),
- gen_key(Function::parse({"a", "b", "c"}, "a+b"), PassParams::ARRAY));
+ EXPECT_NOT_EQUAL(gen_key(*Function::parse({"a", "b"}, "a+b"), PassParams::SEPARATE),
+ gen_key(*Function::parse({"a", "b", "c"}, "a+b"), PassParams::SEPARATE));
+ EXPECT_NOT_EQUAL(gen_key(*Function::parse({"a", "b"}, "a+b"), PassParams::ARRAY),
+ gen_key(*Function::parse({"a", "b", "c"}, "a+b"), PassParams::ARRAY));
}
TEST("require that implicit and explicit parameters give the same function key") {
- EXPECT_EQUAL(gen_key(Function::parse({"a", "b"}, "a+b"), PassParams::SEPARATE),
- gen_key(Function::parse("a+b"), PassParams::SEPARATE));
- EXPECT_EQUAL(gen_key(Function::parse({"a", "b"}, "a+b"), PassParams::ARRAY),
- gen_key(Function::parse("a+b"), PassParams::ARRAY));
+ EXPECT_EQUAL(gen_key(*Function::parse({"a", "b"}, "a+b"), PassParams::SEPARATE),
+ gen_key(*Function::parse("a+b"), PassParams::SEPARATE));
+ EXPECT_EQUAL(gen_key(*Function::parse({"a", "b"}, "a+b"), PassParams::ARRAY),
+ gen_key(*Function::parse("a+b"), PassParams::ARRAY));
}
TEST("require that symbol names does not affect function key") {
- EXPECT_EQUAL(gen_key(Function::parse("a+b"), PassParams::SEPARATE),
- gen_key(Function::parse("x+y"), PassParams::SEPARATE));
- EXPECT_EQUAL(gen_key(Function::parse("a+b"), PassParams::ARRAY),
- gen_key(Function::parse("x+y"), PassParams::ARRAY));
-}
-
-TEST("require that let bind names does not affect function key") {
- EXPECT_EQUAL(gen_key(Function::parse("let(a,1,a+a)"), PassParams::SEPARATE),
- gen_key(Function::parse("let(b,1,b+b)"), PassParams::SEPARATE));
- EXPECT_EQUAL(gen_key(Function::parse("let(a,1,a+a)"), PassParams::ARRAY),
- gen_key(Function::parse("let(b,1,b+b)"), PassParams::ARRAY));
+ EXPECT_EQUAL(gen_key(*Function::parse("a+b"), PassParams::SEPARATE),
+ gen_key(*Function::parse("x+y"), PassParams::SEPARATE));
+ EXPECT_EQUAL(gen_key(*Function::parse("a+b"), PassParams::ARRAY),
+ gen_key(*Function::parse("x+y"), PassParams::ARRAY));
}
TEST("require that different values give different function keys") {
- EXPECT_NOT_EQUAL(gen_key(Function::parse("1"), PassParams::SEPARATE),
- gen_key(Function::parse("2"), PassParams::SEPARATE));
- EXPECT_NOT_EQUAL(gen_key(Function::parse("1"), PassParams::ARRAY),
- gen_key(Function::parse("2"), PassParams::ARRAY));
+ EXPECT_NOT_EQUAL(gen_key(*Function::parse("1"), PassParams::SEPARATE),
+ gen_key(*Function::parse("2"), PassParams::SEPARATE));
+ EXPECT_NOT_EQUAL(gen_key(*Function::parse("1"), PassParams::ARRAY),
+ gen_key(*Function::parse("2"), PassParams::ARRAY));
}
TEST("require that different strings give different function keys") {
- EXPECT_NOT_EQUAL(gen_key(Function::parse("\"a\""), PassParams::SEPARATE),
- gen_key(Function::parse("\"b\""), PassParams::SEPARATE));
- EXPECT_NOT_EQUAL(gen_key(Function::parse("\"a\""), PassParams::ARRAY),
- gen_key(Function::parse("\"b\""), PassParams::ARRAY));
+ EXPECT_NOT_EQUAL(gen_key(*Function::parse("\"a\""), PassParams::SEPARATE),
+ gen_key(*Function::parse("\"b\""), PassParams::SEPARATE));
+ EXPECT_NOT_EQUAL(gen_key(*Function::parse("\"a\""), PassParams::ARRAY),
+ gen_key(*Function::parse("\"b\""), PassParams::ARRAY));
}
//-----------------------------------------------------------------------------
@@ -69,11 +66,11 @@ struct CheckKeys : test::EvalSpec::EvalTest {
virtual void next_expression(const std::vector<vespalib::string> &param_names,
const vespalib::string &expression) override
{
- Function function = Function::parse(param_names, expression);
- if (!CompiledFunction::detect_issues(function)) {
- if (check_key(gen_key(function, PassParams::ARRAY)) ||
- check_key(gen_key(function, PassParams::SEPARATE)) ||
- check_key(gen_key(function, PassParams::LAZY)))
+ auto function = Function::parse(param_names, expression);
+ if (!CompiledFunction::detect_issues(*function)) {
+ if (check_key(gen_key(*function, PassParams::ARRAY)) ||
+ check_key(gen_key(*function, PassParams::SEPARATE)) ||
+ check_key(gen_key(*function, PassParams::LAZY)))
{
failed = true;
fprintf(stderr, "key collision for: %s\n", expression.c_str());
@@ -107,33 +104,33 @@ TEST("require that cache is initially empty") {
}
TEST("require that unused functions are evicted from the cache") {
- CompileCache::Token::UP token_a = CompileCache::compile(Function::parse("x+y"), PassParams::ARRAY);
+ CompileCache::Token::UP token_a = CompileCache::compile(*Function::parse("x+y"), PassParams::ARRAY);
TEST_DO(verify_cache(1, 1));
token_a.reset();
TEST_DO(verify_cache(0, 0));
}
TEST("require that agents can have separate functions in the cache") {
- CompileCache::Token::UP token_a = CompileCache::compile(Function::parse("x+y"), PassParams::ARRAY);
- CompileCache::Token::UP token_b = CompileCache::compile(Function::parse("x*y"), PassParams::ARRAY);
+ CompileCache::Token::UP token_a = CompileCache::compile(*Function::parse("x+y"), PassParams::ARRAY);
+ CompileCache::Token::UP token_b = CompileCache::compile(*Function::parse("x*y"), PassParams::ARRAY);
TEST_DO(verify_cache(2, 2));
}
TEST("require that agents can share functions in the cache") {
- CompileCache::Token::UP token_a = CompileCache::compile(Function::parse("x+y"), PassParams::ARRAY);
- CompileCache::Token::UP token_b = CompileCache::compile(Function::parse("x+y"), PassParams::ARRAY);
+ CompileCache::Token::UP token_a = CompileCache::compile(*Function::parse("x+y"), PassParams::ARRAY);
+ CompileCache::Token::UP token_b = CompileCache::compile(*Function::parse("x+y"), PassParams::ARRAY);
TEST_DO(verify_cache(1, 2));
}
TEST("require that cache usage works") {
TEST_DO(verify_cache(0, 0));
- CompileCache::Token::UP token_a = CompileCache::compile(Function::parse("x+y"), PassParams::SEPARATE);
+ CompileCache::Token::UP token_a = CompileCache::compile(*Function::parse("x+y"), PassParams::SEPARATE);
EXPECT_EQUAL(5.0, token_a->get().get_function<2>()(2.0, 3.0));
TEST_DO(verify_cache(1, 1));
- CompileCache::Token::UP token_b = CompileCache::compile(Function::parse("x*y"), PassParams::SEPARATE);
+ CompileCache::Token::UP token_b = CompileCache::compile(*Function::parse("x*y"), PassParams::SEPARATE);
EXPECT_EQUAL(6.0, token_b->get().get_function<2>()(2.0, 3.0));
TEST_DO(verify_cache(2, 2));
- CompileCache::Token::UP token_c = CompileCache::compile(Function::parse("x+y"), PassParams::SEPARATE);
+ CompileCache::Token::UP token_c = CompileCache::compile(*Function::parse("x+y"), PassParams::SEPARATE);
EXPECT_EQUAL(5.0, token_c->get().get_function<2>()(2.0, 3.0));
TEST_DO(verify_cache(2, 3));
token_a.reset();
@@ -144,6 +141,78 @@ TEST("require that cache usage works") {
TEST_DO(verify_cache(0, 0));
}
+struct CompileCheck : test::EvalSpec::EvalTest {
+ struct Entry {
+ CompileCache::Token::UP fun;
+ std::vector<double> params;
+ double expect;
+ Entry(CompileCache::Token::UP fun_in, const std::vector<double> &params_in, double expect_in)
+ : fun(std::move(fun_in)), params(params_in), expect(expect_in) {}
+ };
+ std::vector<Entry> list;
+ void next_expression(const std::vector<vespalib::string> &,
+ const vespalib::string &) override {}
+ void handle_case(const std::vector<vespalib::string> &param_names,
+ const std::vector<double> &param_values,
+ const vespalib::string &expression,
+ double expected_result) override
+ {
+ auto function = Function::parse(param_names, expression);
+ ASSERT_TRUE(!function->has_error());
+ bool has_issues = CompiledFunction::detect_issues(*function);
+ if (!has_issues) {
+ list.emplace_back(CompileCache::compile(*function, PassParams::ARRAY), param_values, expected_result);
+ }
+ }
+ void verify() {
+ for (const Entry &entry: list) {
+ auto fun = entry.fun->get().get_function();
+ if (std::isnan(entry.expect)) {
+ EXPECT_TRUE(std::isnan(fun(&entry.params[0])));
+ } else {
+ EXPECT_EQUAL(fun(&entry.params[0]), entry.expect);
+ }
+ }
+ }
+};
+
+TEST_F("compile sequentially, then run all conformance tests", test::EvalSpec()) {
+ f1.add_all_cases();
+ for (size_t i = 0; i < 4; ++i) {
+ CompileCheck test;
+ auto t0 = steady_clock::now();
+ f1.each_case(test);
+ auto t1 = steady_clock::now();
+ auto t2 = steady_clock::now();
+ test.verify();
+ auto t3 = steady_clock::now();
+ fprintf(stderr, "sequential (run %zu): setup: %zu ms, wait: %zu ms, verify: %zu us, total: %zu ms\n",
+ i, count_ms(t1 - t0), count_ms(t2 - t1), count_us(t3 - t2), count_ms(t3 - t0));
+ }
+}
+
+TEST_FF("compile concurrently (8 threads), then run all conformance tests", test::EvalSpec(), TimeBomb(60)) {
+ f1.add_all_cases();
+ ThreadStackExecutor executor(8, 256*1024);
+ CompileCache::attach_executor(executor);
+ while (executor.num_idle_workers() < 8) {
+ std::this_thread::sleep_for(1ms);
+ }
+ for (size_t i = 0; i < 4; ++i) {
+ CompileCheck test;
+ auto t0 = steady_clock::now();
+ f1.each_case(test);
+ auto t1 = steady_clock::now();
+ executor.sync();
+ auto t2 = steady_clock::now();
+ test.verify();
+ auto t3 = steady_clock::now();
+ fprintf(stderr, "concurrent (run %zu): setup: %zu ms, wait: %zu ms, verify: %zu us, total: %zu ms\n",
+ i, count_ms(t1 - t0), count_ms(t2 - t1), count_us(t3 - t2), count_ms(t3 - t0));
+ }
+ CompileCache::detach_executor();
+}
+
//-----------------------------------------------------------------------------
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/tests/eval/compiled_function/compiled_function_test.cpp b/eval/src/tests/eval/compiled_function/compiled_function_test.cpp
index 1d5a6929083..b01c849da1e 100644
--- a/eval/src/tests/eval/compiled_function/compiled_function_test.cpp
+++ b/eval/src/tests/eval/compiled_function/compiled_function_test.cpp
@@ -18,7 +18,7 @@ std::vector<vespalib::string> params_10({"p1", "p2", "p3", "p4", "p5", "p6", "p7
const char *expr_10 = "p1 + p2 + p3 + p4 + p5 + p6 + p7 + p8 + p9 + p10";
TEST("require that separate parameter passing works") {
- CompiledFunction cf_10(Function::parse(params_10, expr_10), PassParams::SEPARATE);
+ CompiledFunction cf_10(*Function::parse(params_10, expr_10), PassParams::SEPARATE);
auto fun_10 = cf_10.get_function<10>();
EXPECT_EQUAL(10.0, fun_10(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0));
EXPECT_EQUAL(50.0, fun_10(5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0));
@@ -27,7 +27,7 @@ TEST("require that separate parameter passing works") {
}
TEST("require that array parameter passing works") {
- CompiledFunction arr_cf(Function::parse(params_10, expr_10), PassParams::ARRAY);
+ CompiledFunction arr_cf(*Function::parse(params_10, expr_10), PassParams::ARRAY);
auto arr_fun = arr_cf.get_function();
EXPECT_EQUAL(10.0, arr_fun(&std::vector<double>({1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0})[0]));
EXPECT_EQUAL(50.0, arr_fun(&std::vector<double>({5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0})[0]));
@@ -38,7 +38,7 @@ TEST("require that array parameter passing works") {
double my_resolve(void *ctx, size_t idx) { return ((double *)ctx)[idx]; }
TEST("require that lazy parameter passing works") {
- CompiledFunction lazy_cf(Function::parse(params_10, expr_10), PassParams::LAZY);
+ CompiledFunction lazy_cf(*Function::parse(params_10, expr_10), PassParams::LAZY);
auto lazy_fun = lazy_cf.get_lazy_function();
EXPECT_EQUAL(10.0, lazy_fun(my_resolve, &std::vector<double>({1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0})[0]));
EXPECT_EQUAL(50.0, lazy_fun(my_resolve, &std::vector<double>({5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0})[0]));
@@ -79,10 +79,10 @@ struct MyEvalTest : test::EvalSpec::EvalTest {
virtual void next_expression(const std::vector<vespalib::string> &param_names,
const vespalib::string &expression) override
{
- Function function = Function::parse(param_names, expression);
- ASSERT_TRUE(!function.has_error());
+ auto function = Function::parse(param_names, expression);
+ ASSERT_TRUE(!function->has_error());
bool is_supported = !is_unsupported(expression);
- bool has_issues = CompiledFunction::detect_issues(function);
+ bool has_issues = CompiledFunction::detect_issues(*function);
if (is_supported == has_issues) {
const char *supported_str = is_supported ? "supported" : "not supported";
const char *issues_str = has_issues ? "has issues" : "does not have issues";
@@ -96,12 +96,12 @@ struct MyEvalTest : test::EvalSpec::EvalTest {
const vespalib::string &expression,
double expected_result) override
{
- Function function = Function::parse(param_names, expression);
- ASSERT_TRUE(!function.has_error());
+ auto function = Function::parse(param_names, expression);
+ ASSERT_TRUE(!function->has_error());
bool is_supported = !is_unsupported(expression);
- bool has_issues = CompiledFunction::detect_issues(function);
+ bool has_issues = CompiledFunction::detect_issues(*function);
if (is_supported && !has_issues) {
- CompiledFunction cfun(function, PassParams::ARRAY);
+ CompiledFunction cfun(*function, PassParams::ARRAY);
auto fun = cfun.get_function();
ASSERT_EQUAL(cfun.num_params(), param_values.size());
double result = fun(&param_values[0]);
@@ -135,9 +135,9 @@ TEST("require that large (plugin) set membership checks work") {
for(size_t i = 1; i <= 100; ++i) {
my_in->add_entry(std::make_unique<nodes::Number>(i));
}
- Function my_fun(std::move(my_in), {"a"});
- CompiledFunction cf(my_fun, PassParams::SEPARATE);
- CompiledFunction arr_cf(my_fun, PassParams::ARRAY);
+ auto my_fun = Function::create(std::move(my_in), {"a"});
+ CompiledFunction cf(*my_fun, PassParams::SEPARATE);
+ CompiledFunction arr_cf(*my_fun, PassParams::ARRAY);
auto fun = cf.get_function<1>();
auto arr_fun = arr_cf.get_function();
for (double value = 0.5; value <= 100.5; value += 0.5) {
@@ -160,7 +160,7 @@ CompiledFunction pass_fun(CompiledFunction cf) {
}
TEST("require that compiled expression can be passed (moved) around") {
- CompiledFunction cf(Function::parse("a+b"), PassParams::SEPARATE);
+ CompiledFunction cf(*Function::parse("a+b"), PassParams::SEPARATE);
auto fun = cf.get_function<2>();
EXPECT_EQUAL(4.0, fun(2.0, 2.0));
CompiledFunction cf2 = pass_fun(std::move(cf));
@@ -171,16 +171,16 @@ TEST("require that compiled expression can be passed (moved) around") {
}
TEST("require that expressions with constant sub-expressions evaluate correctly") {
- CompiledFunction cf(Function::parse("if(1,2,10)+a+b+max(1,2)/1"), PassParams::SEPARATE);
+ CompiledFunction cf(*Function::parse("if(1,2,10)+a+b+max(1,2)/1"), PassParams::SEPARATE);
auto fun = cf.get_function<2>();
EXPECT_EQUAL(7.0, fun(1.0, 2.0));
EXPECT_EQUAL(11.0, fun(3.0, 4.0));
}
TEST("dump ir code to verify lazy casting") {
- Function function = Function::parse({"a", "b"}, "12==2+if(a==3&&a<10||b,10,5)");
+ auto function = Function::parse({"a", "b"}, "12==2+if(a==3&&a<10||b,10,5)");
LLVMWrapper wrapper;
- size_t id = wrapper.make_function(function.num_params(), PassParams::SEPARATE, function.root(), {});
+ size_t id = wrapper.make_function(function->num_params(), PassParams::SEPARATE, function->root(), {});
wrapper.compile(llvm::dbgs()); // dump module before compiling it
using fun_type = double (*)(double, double);
fun_type fun = (fun_type) wrapper.get_function_address(id);
@@ -189,30 +189,32 @@ TEST("dump ir code to verify lazy casting") {
EXPECT_EQUAL(1.0, fun(3.0, 0.0));
}
-TEST_MT("require that multithreaded compilation works", 64) {
- {
- CompiledFunction cf(Function::parse({"x", "y", "z", "w"}, "((x+1)*(y-1))/((z+1)/(w-1))"),
- PassParams::SEPARATE);
- auto fun = cf.get_function<4>();
- EXPECT_EQUAL(1.0, fun(0.0, 2.0, 0.0, 2.0));
- }
- {
- CompiledFunction cf(Function::parse({"x", "y", "z", "w"}, "((x+1)*(y-1))/((z+1)/(w-1))"),
- PassParams::SEPARATE);
- auto fun = cf.get_function<4>();
- EXPECT_EQUAL(4.0, fun(1.0, 3.0, 0.0, 2.0));
- }
- {
- CompiledFunction cf(Function::parse({"x", "y", "z", "w"}, "((x+1)*(y-1))/((z+1)/(w-1))"),
- PassParams::SEPARATE);
- auto fun = cf.get_function<4>();
- EXPECT_EQUAL(2.0, fun(1.0, 3.0, 1.0, 2.0));
- }
- {
- CompiledFunction cf(Function::parse({"x", "y", "z", "w"}, "((x+1)*(y-1))/((z+1)/(w-1))"),
- PassParams::SEPARATE);
- auto fun = cf.get_function<4>();
- EXPECT_EQUAL(8.0, fun(1.0, 3.0, 1.0, 5.0));
+TEST_MT("require that multithreaded compilation works", 32) {
+ for (size_t i = 0; i < 16; ++i) {
+ {
+ CompiledFunction cf(*Function::parse({"x", "y", "z", "w"}, "((x+1)*(y-1))/((z+1)/(w-1))"),
+ PassParams::SEPARATE);
+ auto fun = cf.get_function<4>();
+ EXPECT_EQUAL(1.0, fun(0.0, 2.0, 0.0, 2.0));
+ }
+ {
+ CompiledFunction cf(*Function::parse({"x", "y", "z", "w"}, "((x+1)*(y-1))/((z+1)/(w-1))"),
+ PassParams::SEPARATE);
+ auto fun = cf.get_function<4>();
+ EXPECT_EQUAL(4.0, fun(1.0, 3.0, 0.0, 2.0));
+ }
+ {
+ CompiledFunction cf(*Function::parse({"x", "y", "z", "w"}, "((x+1)*(y-1))/((z+1)/(w-1))"),
+ PassParams::SEPARATE);
+ auto fun = cf.get_function<4>();
+ EXPECT_EQUAL(2.0, fun(1.0, 3.0, 1.0, 2.0));
+ }
+ {
+ CompiledFunction cf(*Function::parse({"x", "y", "z", "w"}, "((x+1)*(y-1))/((z+1)/(w-1))"),
+ PassParams::SEPARATE);
+ auto fun = cf.get_function<4>();
+ EXPECT_EQUAL(8.0, fun(1.0, 3.0, 1.0, 5.0));
+ }
}
}
@@ -221,12 +223,12 @@ TEST_MT("require that multithreaded compilation works", 64) {
TEST("require that function issues can be detected") {
auto simple = Function::parse("a+b");
auto complex = Function::parse("join(a,b,f(a,b)(a+b))");
- EXPECT_FALSE(simple.has_error());
- EXPECT_FALSE(complex.has_error());
- EXPECT_FALSE(CompiledFunction::detect_issues(simple));
- EXPECT_TRUE(CompiledFunction::detect_issues(complex));
+ EXPECT_FALSE(simple->has_error());
+ EXPECT_FALSE(complex->has_error());
+ EXPECT_FALSE(CompiledFunction::detect_issues(*simple));
+ EXPECT_TRUE(CompiledFunction::detect_issues(*complex));
std::cerr << "Example function issues:" << std::endl
- << CompiledFunction::detect_issues(complex).list
+ << CompiledFunction::detect_issues(*complex).list
<< std::endl;
}
diff --git a/eval/src/tests/eval/function/function_test.cpp b/eval/src/tests/eval/function/function_test.cpp
index 7cfe54b4b95..fef0ed35668 100644
--- a/eval/src/tests/eval/function/function_test.cpp
+++ b/eval/src/tests/eval/function/function_test.cpp
@@ -70,59 +70,59 @@ void verify_operator_binding_order(std::initializer_list<OperatorLayer> layers)
bool verify_string(const vespalib::string &str, const vespalib::string &expr) {
bool ok = true;
- ok &= EXPECT_EQUAL(str, as_string(Function::parse(params, expr)));
- ok &= EXPECT_EQUAL(expr, Function::parse(params, expr).dump());
+ ok &= EXPECT_EQUAL(str, as_string(*Function::parse(params, expr)));
+ ok &= EXPECT_EQUAL(expr, Function::parse(params, expr)->dump());
return ok;
}
void verify_error(const vespalib::string &expr, const vespalib::string &expected_error) {
- Function function = Function::parse(params, expr);
- EXPECT_TRUE(function.has_error());
- EXPECT_EQUAL(expected_error, function.get_error());
+ auto function = Function::parse(params, expr);
+ EXPECT_TRUE(function->has_error());
+ EXPECT_EQUAL(expected_error, function->get_error());
}
void verify_parse(const vespalib::string &expr, const vespalib::string &expect) {
- Function function = Function::parse(expr);
- EXPECT_TRUE(!function.has_error());
- EXPECT_EQUAL(function.dump_as_lambda(), expect);
+ auto function = Function::parse(expr);
+ EXPECT_TRUE(!function->has_error());
+ EXPECT_EQUAL(function->dump_as_lambda(), expect);
}
TEST("require that scientific numbers can be parsed") {
- EXPECT_EQUAL(1.0, as_number(Function::parse(params, "1")));
- EXPECT_EQUAL(2.5, as_number(Function::parse(params, "2.5")));
- EXPECT_EQUAL(100.0, as_number(Function::parse(params, "100")));
- EXPECT_EQUAL(0.01, as_number(Function::parse(params, "0.01")));
- EXPECT_EQUAL(1.05e5, as_number(Function::parse(params, "1.05e5")));
- EXPECT_EQUAL(3e7, as_number(Function::parse(params, "3e7")));
- EXPECT_EQUAL(1.05e5, as_number(Function::parse(params, "1.05e+5")));
- EXPECT_EQUAL(3e7, as_number(Function::parse(params, "3e+7")));
- EXPECT_EQUAL(1.05e-5, as_number(Function::parse(params, "1.05e-5")));
- EXPECT_EQUAL(3e-7, as_number(Function::parse(params, "3e-7")));
- EXPECT_EQUAL(1.05e5, as_number(Function::parse(params, "1.05E5")));
- EXPECT_EQUAL(3e7, as_number(Function::parse(params, "3E7")));
- EXPECT_EQUAL(1.05e5, as_number(Function::parse(params, "1.05E+5")));
- EXPECT_EQUAL(3e7, as_number(Function::parse(params, "3E+7")));
- EXPECT_EQUAL(1.05e-5, as_number(Function::parse(params, "1.05E-5")));
- EXPECT_EQUAL(3e-7, as_number(Function::parse(params, "3E-7")));
+ EXPECT_EQUAL(1.0, as_number(*Function::parse(params, "1")));
+ EXPECT_EQUAL(2.5, as_number(*Function::parse(params, "2.5")));
+ EXPECT_EQUAL(100.0, as_number(*Function::parse(params, "100")));
+ EXPECT_EQUAL(0.01, as_number(*Function::parse(params, "0.01")));
+ EXPECT_EQUAL(1.05e5, as_number(*Function::parse(params, "1.05e5")));
+ EXPECT_EQUAL(3e7, as_number(*Function::parse(params, "3e7")));
+ EXPECT_EQUAL(1.05e5, as_number(*Function::parse(params, "1.05e+5")));
+ EXPECT_EQUAL(3e7, as_number(*Function::parse(params, "3e+7")));
+ EXPECT_EQUAL(1.05e-5, as_number(*Function::parse(params, "1.05e-5")));
+ EXPECT_EQUAL(3e-7, as_number(*Function::parse(params, "3e-7")));
+ EXPECT_EQUAL(1.05e5, as_number(*Function::parse(params, "1.05E5")));
+ EXPECT_EQUAL(3e7, as_number(*Function::parse(params, "3E7")));
+ EXPECT_EQUAL(1.05e5, as_number(*Function::parse(params, "1.05E+5")));
+ EXPECT_EQUAL(3e7, as_number(*Function::parse(params, "3E+7")));
+ EXPECT_EQUAL(1.05e-5, as_number(*Function::parse(params, "1.05E-5")));
+ EXPECT_EQUAL(3e-7, as_number(*Function::parse(params, "3E-7")));
}
TEST("require that number parsing does not eat +/- operators") {
- EXPECT_EQUAL("(((1+2)+3)+4)", Function::parse(params, "1+2+3+4").dump());
- EXPECT_EQUAL("(((1-2)-3)-4)", Function::parse(params, "1-2-3-4").dump());
- EXPECT_EQUAL("(((1+x)+3)+y)", Function::parse(params, "1+x+3+y").dump());
- EXPECT_EQUAL("(((1-x)-3)-y)", Function::parse(params, "1-x-3-y").dump());
+ EXPECT_EQUAL("(((1+2)+3)+4)", Function::parse(params, "1+2+3+4")->dump());
+ EXPECT_EQUAL("(((1-2)-3)-4)", Function::parse(params, "1-2-3-4")->dump());
+ EXPECT_EQUAL("(((1+x)+3)+y)", Function::parse(params, "1+x+3+y")->dump());
+ EXPECT_EQUAL("(((1-x)-3)-y)", Function::parse(params, "1-x-3-y")->dump());
}
TEST("require that symbols can be parsed") {
- EXPECT_EQUAL("x", Function::parse(params, "x").dump());
- EXPECT_EQUAL("y", Function::parse(params, "y").dump());
- EXPECT_EQUAL("z", Function::parse(params, "z").dump());
+ EXPECT_EQUAL("x", Function::parse(params, "x")->dump());
+ EXPECT_EQUAL("y", Function::parse(params, "y")->dump());
+ EXPECT_EQUAL("z", Function::parse(params, "z")->dump());
}
TEST("require that parenthesis can be parsed") {
- EXPECT_EQUAL("x", Function::parse(params, "(x)").dump());
- EXPECT_EQUAL("x", Function::parse(params, "((x))").dump());
- EXPECT_EQUAL("x", Function::parse(params, "(((x)))").dump());
+ EXPECT_EQUAL("x", Function::parse(params, "(x)")->dump());
+ EXPECT_EQUAL("x", Function::parse(params, "((x))")->dump());
+ EXPECT_EQUAL("x", Function::parse(params, "(((x)))")->dump());
}
TEST("require that strings are parsed and dumped correctly") {
@@ -139,31 +139,31 @@ TEST("require that strings are parsed and dumped correctly") {
vespalib::string raw_expr = vespalib::make_string("\"%c\"", c);
vespalib::string hex_expr = vespalib::make_string("\"\\x%02x\"", c);
vespalib::string raw_str = vespalib::make_string("%c", c);
- EXPECT_EQUAL(raw_str, as_string(Function::parse(params, hex_expr)));
+ EXPECT_EQUAL(raw_str, as_string(*Function::parse(params, hex_expr)));
if (c != 0 && c != '\"' && c != '\\') {
- EXPECT_EQUAL(raw_str, as_string(Function::parse(params, raw_expr)));
+ EXPECT_EQUAL(raw_str, as_string(*Function::parse(params, raw_expr)));
} else {
- EXPECT_TRUE(Function::parse(params, raw_expr).has_error());
+ EXPECT_TRUE(Function::parse(params, raw_expr)->has_error());
}
if (c == '\\') {
- EXPECT_EQUAL("\"\\\\\"", Function::parse(params, hex_expr).dump());
+ EXPECT_EQUAL("\"\\\\\"", Function::parse(params, hex_expr)->dump());
} else if (c == '\"') {
- EXPECT_EQUAL("\"\\\"\"", Function::parse(params, hex_expr).dump());
+ EXPECT_EQUAL("\"\\\"\"", Function::parse(params, hex_expr)->dump());
} else if (c == '\t') {
- EXPECT_EQUAL("\"\\t\"", Function::parse(params, hex_expr).dump());
+ EXPECT_EQUAL("\"\\t\"", Function::parse(params, hex_expr)->dump());
} else if (c == '\n') {
- EXPECT_EQUAL("\"\\n\"", Function::parse(params, hex_expr).dump());
+ EXPECT_EQUAL("\"\\n\"", Function::parse(params, hex_expr)->dump());
} else if (c == '\r') {
- EXPECT_EQUAL("\"\\r\"", Function::parse(params, hex_expr).dump());
+ EXPECT_EQUAL("\"\\r\"", Function::parse(params, hex_expr)->dump());
} else if (c == '\f') {
- EXPECT_EQUAL("\"\\f\"", Function::parse(params, hex_expr).dump());
+ EXPECT_EQUAL("\"\\f\"", Function::parse(params, hex_expr)->dump());
} else if ((c >= 32) && (c <= 126)) {
if (c >= 'a' && c <= 'z' && c != 't' && c != 'n' && c != 'r' && c != 'f') {
- EXPECT_TRUE(Function::parse(params, vespalib::make_string("\"\\%c\"", c)).has_error());
+ EXPECT_TRUE(Function::parse(params, vespalib::make_string("\"\\%c\"", c))->has_error());
}
- EXPECT_EQUAL(raw_expr, Function::parse(params, hex_expr).dump());
+ EXPECT_EQUAL(raw_expr, Function::parse(params, hex_expr)->dump());
} else {
- EXPECT_EQUAL(hex_expr, Function::parse(params, hex_expr).dump());
+ EXPECT_EQUAL(hex_expr, Function::parse(params, hex_expr)->dump());
}
}
}
@@ -173,36 +173,36 @@ TEST("require that free arrays cannot be parsed") {
}
TEST("require that negative values can be parsed") {
- EXPECT_EQUAL("-1", Function::parse(params, "-1").dump());
- EXPECT_EQUAL("1", Function::parse(params, "--1").dump());
- EXPECT_EQUAL("-1", Function::parse(params, " ( - ( - ( - ( (1) ) ) ) )").dump());
- EXPECT_EQUAL("-2.5", Function::parse(params, "-2.5").dump());
- EXPECT_EQUAL("-100", Function::parse(params, "-100").dump());
+ EXPECT_EQUAL("-1", Function::parse(params, "-1")->dump());
+ EXPECT_EQUAL("1", Function::parse(params, "--1")->dump());
+ EXPECT_EQUAL("-1", Function::parse(params, " ( - ( - ( - ( (1) ) ) ) )")->dump());
+ EXPECT_EQUAL("-2.5", Function::parse(params, "-2.5")->dump());
+ EXPECT_EQUAL("-100", Function::parse(params, "-100")->dump());
}
TEST("require that negative symbols can be parsed") {
- EXPECT_EQUAL("(-x)", Function::parse(params, "-x").dump());
- EXPECT_EQUAL("(-y)", Function::parse(params, "-y").dump());
- EXPECT_EQUAL("(-z)", Function::parse(params, "-z").dump());
- EXPECT_EQUAL("(-(-(-x)))", Function::parse(params, "---x").dump());
+ EXPECT_EQUAL("(-x)", Function::parse(params, "-x")->dump());
+ EXPECT_EQUAL("(-y)", Function::parse(params, "-y")->dump());
+ EXPECT_EQUAL("(-z)", Function::parse(params, "-z")->dump());
+ EXPECT_EQUAL("(-(-(-x)))", Function::parse(params, "---x")->dump());
}
TEST("require that not can be parsed") {
- EXPECT_EQUAL("(!x)", Function::parse(params, "!x").dump());
- EXPECT_EQUAL("(!(!x))", Function::parse(params, "!!x").dump());
- EXPECT_EQUAL("(!(!(!x)))", Function::parse(params, "!!!x").dump());
+ EXPECT_EQUAL("(!x)", Function::parse(params, "!x")->dump());
+ EXPECT_EQUAL("(!(!x))", Function::parse(params, "!!x")->dump());
+ EXPECT_EQUAL("(!(!(!x)))", Function::parse(params, "!!!x")->dump());
}
TEST("require that not/neg binds to next value") {
- EXPECT_EQUAL("((!(!(-(-x))))^z)", Function::parse(params, "!!--x^z").dump());
- EXPECT_EQUAL("((-(-(!(!x))))^z)", Function::parse(params, "--!!x^z").dump());
- EXPECT_EQUAL("((!(-(-(!x))))^z)", Function::parse(params, "!--!x^z").dump());
- EXPECT_EQUAL("((-(!(!(-x))))^z)", Function::parse(params, "-!!-x^z").dump());
+ EXPECT_EQUAL("((!(!(-(-x))))^z)", Function::parse(params, "!!--x^z")->dump());
+ EXPECT_EQUAL("((-(-(!(!x))))^z)", Function::parse(params, "--!!x^z")->dump());
+ EXPECT_EQUAL("((!(-(-(!x))))^z)", Function::parse(params, "!--!x^z")->dump());
+ EXPECT_EQUAL("((-(!(!(-x))))^z)", Function::parse(params, "-!!-x^z")->dump());
}
TEST("require that parenthesis resolves before not/neg") {
- EXPECT_EQUAL("(!(x^z))", Function::parse(params, "!(x^z)").dump());
- EXPECT_EQUAL("(-(x^z))", Function::parse(params, "-(x^z)").dump());
+ EXPECT_EQUAL("(!(x^z))", Function::parse(params, "!(x^z)")->dump());
+ EXPECT_EQUAL("(-(x^z))", Function::parse(params, "-(x^z)")->dump());
}
TEST("require that operators have appropriate binding order") {
@@ -216,48 +216,48 @@ TEST("require that operators have appropriate binding order") {
TEST("require that operators binding left are calculated left to right") {
EXPECT_TRUE(create_op("+")->order() == Operator::Order::LEFT);
- EXPECT_EQUAL("((x+y)+z)", Function::parse(params, "x+y+z").dump());
+ EXPECT_EQUAL("((x+y)+z)", Function::parse(params, "x+y+z")->dump());
}
TEST("require that operators binding right are calculated right to left") {
EXPECT_TRUE(create_op("^")->order() == Operator::Order::RIGHT);
- EXPECT_EQUAL("(x^(y^z))", Function::parse(params, "x^y^z").dump());
+ EXPECT_EQUAL("(x^(y^z))", Function::parse(params, "x^y^z")->dump());
}
TEST("require that operators with higher precedence are resolved first") {
EXPECT_TRUE(create_op("*")->priority() > create_op("+")->priority());
- EXPECT_EQUAL("(x+(y*z))", Function::parse(params, "x+y*z").dump());
- EXPECT_EQUAL("((x*y)+z)", Function::parse(params, "x*y+z").dump());
+ EXPECT_EQUAL("(x+(y*z))", Function::parse(params, "x+y*z")->dump());
+ EXPECT_EQUAL("((x*y)+z)", Function::parse(params, "x*y+z")->dump());
}
TEST("require that multi-level operator precedence resolving works") {
EXPECT_TRUE(create_op("^")->priority() > create_op("*")->priority());
EXPECT_TRUE(create_op("*")->priority() > create_op("+")->priority());
- EXPECT_EQUAL("(x+(y*(z^w)))", Function::parse(params, "x+y*z^w").dump());
- EXPECT_EQUAL("(x+((y^z)*w))", Function::parse(params, "x+y^z*w").dump());
- EXPECT_EQUAL("((x*y)+(z^w))", Function::parse(params, "x*y+z^w").dump());
- EXPECT_EQUAL("((x*(y^z))+w)", Function::parse(params, "x*y^z+w").dump());
- EXPECT_EQUAL("((x^y)+(z*w))", Function::parse(params, "x^y+z*w").dump());
- EXPECT_EQUAL("(((x^y)*z)+w)", Function::parse(params, "x^y*z+w").dump());
+ EXPECT_EQUAL("(x+(y*(z^w)))", Function::parse(params, "x+y*z^w")->dump());
+ EXPECT_EQUAL("(x+((y^z)*w))", Function::parse(params, "x+y^z*w")->dump());
+ EXPECT_EQUAL("((x*y)+(z^w))", Function::parse(params, "x*y+z^w")->dump());
+ EXPECT_EQUAL("((x*(y^z))+w)", Function::parse(params, "x*y^z+w")->dump());
+ EXPECT_EQUAL("((x^y)+(z*w))", Function::parse(params, "x^y+z*w")->dump());
+ EXPECT_EQUAL("(((x^y)*z)+w)", Function::parse(params, "x^y*z+w")->dump());
}
TEST("require that expressions are combined when parenthesis are closed") {
- EXPECT_EQUAL("((x+(y+z))+w)", Function::parse(params, "x+(y+z)+w").dump());
+ EXPECT_EQUAL("((x+(y+z))+w)", Function::parse(params, "x+(y+z)+w")->dump());
}
TEST("require that operators can not bind out of parenthesis") {
EXPECT_TRUE(create_op("*")->priority() > create_op("+")->priority());
- EXPECT_EQUAL("((x+y)*(x+z))", Function::parse(params, "(x+y)*(x+z)").dump());
+ EXPECT_EQUAL("((x+y)*(x+z))", Function::parse(params, "(x+y)*(x+z)")->dump());
}
TEST("require that set membership constructs can be parsed") {
- EXPECT_EQUAL("(x in [1,2,3])", Function::parse(params, "x in [1,2,3]").dump());
- EXPECT_EQUAL("(x in [1,2,3])", Function::parse(params, "x in [ 1 , 2 , 3 ] ").dump());
- EXPECT_EQUAL("(x in [-1,-2,-3])", Function::parse(params, "x in [-1,-2,-3]").dump());
- EXPECT_EQUAL("(x in [-1,-2,-3])", Function::parse(params, "x in [ - 1 , - 2 , - 3 ]").dump());
- EXPECT_EQUAL("(x in [1,2,3])", Function::parse(params, "x in[1,2,3]").dump());
- EXPECT_EQUAL("(x in [1,2,3])", Function::parse(params, "(x)in[1,2,3]").dump());
- EXPECT_EQUAL("(x in [\"a\",2,\"c\"])", Function::parse(params, "x in [\"a\",2,\"c\"]").dump());
+ EXPECT_EQUAL("(x in [1,2,3])", Function::parse(params, "x in [1,2,3]")->dump());
+ EXPECT_EQUAL("(x in [1,2,3])", Function::parse(params, "x in [ 1 , 2 , 3 ] ")->dump());
+ EXPECT_EQUAL("(x in [-1,-2,-3])", Function::parse(params, "x in [-1,-2,-3]")->dump());
+ EXPECT_EQUAL("(x in [-1,-2,-3])", Function::parse(params, "x in [ - 1 , - 2 , - 3 ]")->dump());
+ EXPECT_EQUAL("(x in [1,2,3])", Function::parse(params, "x in[1,2,3]")->dump());
+ EXPECT_EQUAL("(x in [1,2,3])", Function::parse(params, "(x)in[1,2,3]")->dump());
+ EXPECT_EQUAL("(x in [\"a\",2,\"c\"])", Function::parse(params, "x in [\"a\",2,\"c\"]")->dump());
}
TEST("require that set membership entries must be array of strings/numbers") {
@@ -270,88 +270,88 @@ TEST("require that set membership entries must be array of strings/numbers") {
}
TEST("require that set membership binds to the next value") {
- EXPECT_EQUAL("((x in [1,2,3])^2)", Function::parse(params, "x in [1,2,3]^2").dump());
+ EXPECT_EQUAL("((x in [1,2,3])^2)", Function::parse(params, "x in [1,2,3]^2")->dump());
}
TEST("require that set membership binds to the left with appropriate precedence") {
- EXPECT_EQUAL("((x<y) in [1,2,3])", Function::parse(params, "x < y in [1,2,3]").dump());
- EXPECT_EQUAL("(x&&(y in [1,2,3]))", Function::parse(params, "x && y in [1,2,3]").dump());
+ EXPECT_EQUAL("((x<y) in [1,2,3])", Function::parse(params, "x < y in [1,2,3]")->dump());
+ EXPECT_EQUAL("(x&&(y in [1,2,3]))", Function::parse(params, "x && y in [1,2,3]")->dump());
}
TEST("require that function calls can be parsed") {
- EXPECT_EQUAL("min(max(x,y),sqrt(z))", Function::parse(params, "min(max(x,y),sqrt(z))").dump());
+ EXPECT_EQUAL("min(max(x,y),sqrt(z))", Function::parse(params, "min(max(x,y),sqrt(z))")->dump());
}
TEST("require that if expressions can be parsed") {
- EXPECT_EQUAL("if(x,y,z)", Function::parse(params, "if(x,y,z)").dump());
- EXPECT_EQUAL("if(x,y,z)", Function::parse(params, "if (x,y,z)").dump());
- EXPECT_EQUAL("if(x,y,z)", Function::parse(params, " if ( x , y , z ) ").dump());
- EXPECT_EQUAL("if(((x>1)&&(y<3)),(y+1),(z-1))", Function::parse(params, "if(x>1&&y<3,y+1,z-1)").dump());
- EXPECT_EQUAL("if(if(x,y,z),if(x,y,z),if(x,y,z))", Function::parse(params, "if(if(x,y,z),if(x,y,z),if(x,y,z))").dump());
- EXPECT_EQUAL("if(x,y,z,0.25)", Function::parse(params, "if(x,y,z,0.25)").dump());
- EXPECT_EQUAL("if(x,y,z,0.75)", Function::parse(params, "if(x,y,z,0.75)").dump());
+ EXPECT_EQUAL("if(x,y,z)", Function::parse(params, "if(x,y,z)")->dump());
+ EXPECT_EQUAL("if(x,y,z)", Function::parse(params, "if (x,y,z)")->dump());
+ EXPECT_EQUAL("if(x,y,z)", Function::parse(params, " if ( x , y , z ) ")->dump());
+ EXPECT_EQUAL("if(((x>1)&&(y<3)),(y+1),(z-1))", Function::parse(params, "if(x>1&&y<3,y+1,z-1)")->dump());
+ EXPECT_EQUAL("if(if(x,y,z),if(x,y,z),if(x,y,z))", Function::parse(params, "if(if(x,y,z),if(x,y,z),if(x,y,z))")->dump());
+ EXPECT_EQUAL("if(x,y,z,0.25)", Function::parse(params, "if(x,y,z,0.25)")->dump());
+ EXPECT_EQUAL("if(x,y,z,0.75)", Function::parse(params, "if(x,y,z,0.75)")->dump());
}
TEST("require that if probability can be inspected") {
- Function fun_1 = Function::parse("if(x,y,z,0.25)");
- auto if_1 = as<If>(fun_1.root());
+ auto fun_1 = Function::parse("if(x,y,z,0.25)");
+ auto if_1 = as<If>(fun_1->root());
ASSERT_TRUE(if_1);
EXPECT_EQUAL(0.25, if_1->p_true());
- Function fun_2 = Function::parse("if(x,y,z,0.75)");
- auto if_2 = as<If>(fun_2.root());
+ auto fun_2 = Function::parse("if(x,y,z,0.75)");
+ auto if_2 = as<If>(fun_2->root());
ASSERT_TRUE(if_2);
EXPECT_EQUAL(0.75, if_2->p_true());
}
TEST("require that symbols can be implicit") {
- EXPECT_EQUAL("x", Function::parse("x").dump());
- EXPECT_EQUAL("y", Function::parse("y").dump());
- EXPECT_EQUAL("z", Function::parse("z").dump());
+ EXPECT_EQUAL("x", Function::parse("x")->dump());
+ EXPECT_EQUAL("y", Function::parse("y")->dump());
+ EXPECT_EQUAL("z", Function::parse("z")->dump());
}
TEST("require that implicit parameters are picket up left to right") {
- Function fun1 = Function::parse("x+y+y");
- Function fun2 = Function::parse("y+y+x");
- EXPECT_EQUAL("((x+y)+y)", fun1.dump());
- EXPECT_EQUAL("((y+y)+x)", fun2.dump());
- ASSERT_EQUAL(2u, fun1.num_params());
- ASSERT_EQUAL(2u, fun2.num_params());
- EXPECT_EQUAL("x", fun1.param_name(0));
- EXPECT_EQUAL("x", fun2.param_name(1));
- EXPECT_EQUAL("y", fun1.param_name(1));
- EXPECT_EQUAL("y", fun2.param_name(0));
+ auto fun1 = Function::parse("x+y+y");
+ auto fun2 = Function::parse("y+y+x");
+ EXPECT_EQUAL("((x+y)+y)", fun1->dump());
+ EXPECT_EQUAL("((y+y)+x)", fun2->dump());
+ ASSERT_EQUAL(2u, fun1->num_params());
+ ASSERT_EQUAL(2u, fun2->num_params());
+ EXPECT_EQUAL("x", fun1->param_name(0));
+ EXPECT_EQUAL("x", fun2->param_name(1));
+ EXPECT_EQUAL("y", fun1->param_name(1));
+ EXPECT_EQUAL("y", fun2->param_name(0));
}
//-----------------------------------------------------------------------------
TEST("require that leaf nodes have no children") {
- EXPECT_TRUE(Function::parse("123").root().is_leaf());
- EXPECT_TRUE(Function::parse("x").root().is_leaf());
- EXPECT_TRUE(Function::parse("\"abc\"").root().is_leaf());
- EXPECT_EQUAL(0u, Function::parse("123").root().num_children());
- EXPECT_EQUAL(0u, Function::parse("x").root().num_children());
- EXPECT_EQUAL(0u, Function::parse("\"abc\"").root().num_children());
+ EXPECT_TRUE(Function::parse("123")->root().is_leaf());
+ EXPECT_TRUE(Function::parse("x")->root().is_leaf());
+ EXPECT_TRUE(Function::parse("\"abc\"")->root().is_leaf());
+ EXPECT_EQUAL(0u, Function::parse("123")->root().num_children());
+ EXPECT_EQUAL(0u, Function::parse("x")->root().num_children());
+ EXPECT_EQUAL(0u, Function::parse("\"abc\"")->root().num_children());
}
TEST("require that Neg child can be accessed") {
- Function f = Function::parse("-x");
- const Node &root = f.root();
+ auto f = Function::parse("-x");
+ const Node &root = f->root();
EXPECT_TRUE(!root.is_leaf());
ASSERT_EQUAL(1u, root.num_children());
EXPECT_TRUE(root.get_child(0).is_param());
}
TEST("require that Not child can be accessed") {
- Function f = Function::parse("!1");
- const Node &root = f.root();
+ auto f = Function::parse("!1");
+ const Node &root = f->root();
EXPECT_TRUE(!root.is_leaf());
ASSERT_EQUAL(1u, root.num_children());
EXPECT_EQUAL(1.0, root.get_child(0).get_const_value());
}
TEST("require that If children can be accessed") {
- Function f = Function::parse("if(1,2,3)");
- const Node &root = f.root();
+ auto f = Function::parse("if(1,2,3)");
+ const Node &root = f->root();
EXPECT_TRUE(!root.is_leaf());
ASSERT_EQUAL(3u, root.num_children());
EXPECT_EQUAL(1.0, root.get_child(0).get_const_value());
@@ -360,8 +360,8 @@ TEST("require that If children can be accessed") {
}
TEST("require that Operator children can be accessed") {
- Function f = Function::parse("1+2");
- const Node &root = f.root();
+ auto f = Function::parse("1+2");
+ const Node &root = f->root();
EXPECT_TRUE(!root.is_leaf());
ASSERT_EQUAL(2u, root.num_children());
EXPECT_EQUAL(1.0, root.get_child(0).get_const_value());
@@ -369,8 +369,8 @@ TEST("require that Operator children can be accessed") {
}
TEST("require that Call children can be accessed") {
- Function f = Function::parse("max(1,2)");
- const Node &root = f.root();
+ auto f = Function::parse("max(1,2)");
+ const Node &root = f->root();
EXPECT_TRUE(!root.is_leaf());
ASSERT_EQUAL(2u, root.num_children());
EXPECT_EQUAL(1.0, root.get_child(0).get_const_value());
@@ -388,8 +388,8 @@ struct MyNodeHandler : public NodeHandler {
size_t detach_from_root(const vespalib::string &expr) {
MyNodeHandler handler;
- Function function = Function::parse(expr);
- nodes::Node &mutable_root = const_cast<nodes::Node&>(function.root());
+ auto function = Function::parse(expr);
+ nodes::Node &mutable_root = const_cast<nodes::Node&>(function->root());
mutable_root.detach_children(handler);
return handler.nodes.size();
}
@@ -444,15 +444,15 @@ struct MyTraverser : public NodeTraverser {
};
size_t verify_traversal(size_t open_true_cnt, const vespalib::string &expression) {
- Function function = Function::parse(expression);
- if (!EXPECT_TRUE(!function.has_error())) {
- fprintf(stderr, "--> %s\n", function.get_error().c_str());
+ auto function = Function::parse(expression);
+ if (!EXPECT_TRUE(!function->has_error())) {
+ fprintf(stderr, "--> %s\n", function->get_error().c_str());
}
MyTraverser traverser(open_true_cnt);
- function.root().traverse(traverser);
+ function->root().traverse(traverser);
size_t offset = 0;
size_t open_cnt = open_true_cnt;
- traverser.verify(function.root(), offset, open_cnt);
+ traverser.verify(function->root(), offset, open_cnt);
EXPECT_EQUAL(offset, traverser.history.size());
return offset;
}
@@ -476,112 +476,112 @@ TEST("require that traversal works as expected") {
//-----------------------------------------------------------------------------
TEST("require that node types can be checked") {
- EXPECT_TRUE(nodes::check_type<nodes::Add>(Function::parse("1+2").root()));
- EXPECT_TRUE(!nodes::check_type<nodes::Add>(Function::parse("1-2").root()));
- EXPECT_TRUE(!nodes::check_type<nodes::Add>(Function::parse("1*2").root()));
- EXPECT_TRUE(!nodes::check_type<nodes::Add>(Function::parse("1/2").root()));
- EXPECT_TRUE((nodes::check_type<nodes::Add, nodes::Sub, nodes::Mul>(Function::parse("1+2").root())));
- EXPECT_TRUE((nodes::check_type<nodes::Add, nodes::Sub, nodes::Mul>(Function::parse("1-2").root())));
- EXPECT_TRUE((nodes::check_type<nodes::Add, nodes::Sub, nodes::Mul>(Function::parse("1*2").root())));
- EXPECT_TRUE((!nodes::check_type<nodes::Add, nodes::Sub, nodes::Mul>(Function::parse("1/2").root())));
+ EXPECT_TRUE(nodes::check_type<nodes::Add>(Function::parse("1+2")->root()));
+ EXPECT_TRUE(!nodes::check_type<nodes::Add>(Function::parse("1-2")->root()));
+ EXPECT_TRUE(!nodes::check_type<nodes::Add>(Function::parse("1*2")->root()));
+ EXPECT_TRUE(!nodes::check_type<nodes::Add>(Function::parse("1/2")->root()));
+ EXPECT_TRUE((nodes::check_type<nodes::Add, nodes::Sub, nodes::Mul>(Function::parse("1+2")->root())));
+ EXPECT_TRUE((nodes::check_type<nodes::Add, nodes::Sub, nodes::Mul>(Function::parse("1-2")->root())));
+ EXPECT_TRUE((nodes::check_type<nodes::Add, nodes::Sub, nodes::Mul>(Function::parse("1*2")->root())));
+ EXPECT_TRUE((!nodes::check_type<nodes::Add, nodes::Sub, nodes::Mul>(Function::parse("1/2")->root())));
}
//-----------------------------------------------------------------------------
TEST("require that parameter is param, but not const") {
- EXPECT_TRUE(Function::parse("x").root().is_param());
- EXPECT_TRUE(!Function::parse("x").root().is_const());
+ EXPECT_TRUE(Function::parse("x")->root().is_param());
+ EXPECT_TRUE(!Function::parse("x")->root().is_const());
}
TEST("require that inverted parameter is not param") {
- EXPECT_TRUE(!Function::parse("-x").root().is_param());
+ EXPECT_TRUE(!Function::parse("-x")->root().is_param());
}
TEST("require that number is const, but not param") {
- EXPECT_TRUE(Function::parse("123").root().is_const());
- EXPECT_TRUE(!Function::parse("123").root().is_param());
+ EXPECT_TRUE(Function::parse("123")->root().is_const());
+ EXPECT_TRUE(!Function::parse("123")->root().is_param());
}
TEST("require that string is const") {
- EXPECT_TRUE(Function::parse("\"x\"").root().is_const());
+ EXPECT_TRUE(Function::parse("\"x\"")->root().is_const());
}
TEST("require that neg is const if sub-expression is const") {
- EXPECT_TRUE(Function::parse("-123").root().is_const());
- EXPECT_TRUE(!Function::parse("-x").root().is_const());
+ EXPECT_TRUE(Function::parse("-123")->root().is_const());
+ EXPECT_TRUE(!Function::parse("-x")->root().is_const());
}
TEST("require that not is const if sub-expression is const") {
- EXPECT_TRUE(Function::parse("!1").root().is_const());
- EXPECT_TRUE(!Function::parse("!x").root().is_const());
+ EXPECT_TRUE(Function::parse("!1")->root().is_const());
+ EXPECT_TRUE(!Function::parse("!x")->root().is_const());
}
TEST("require that operators are cost if both children are const") {
- EXPECT_TRUE(!Function::parse("x+y").root().is_const());
- EXPECT_TRUE(!Function::parse("1+y").root().is_const());
- EXPECT_TRUE(!Function::parse("x+2").root().is_const());
- EXPECT_TRUE(Function::parse("1+2").root().is_const());
+ EXPECT_TRUE(!Function::parse("x+y")->root().is_const());
+ EXPECT_TRUE(!Function::parse("1+y")->root().is_const());
+ EXPECT_TRUE(!Function::parse("x+2")->root().is_const());
+ EXPECT_TRUE(Function::parse("1+2")->root().is_const());
}
TEST("require that set membership is never tagged as const (NB: avoids jit recursion)") {
- EXPECT_TRUE(!Function::parse("x in [x,y,z]").root().is_const());
- EXPECT_TRUE(!Function::parse("1 in [x,y,z]").root().is_const());
- EXPECT_TRUE(!Function::parse("1 in [1,y,z]").root().is_const());
- EXPECT_TRUE(!Function::parse("1 in [1,2,3]").root().is_const());
+ EXPECT_TRUE(!Function::parse("x in [x,y,z]")->root().is_const());
+ EXPECT_TRUE(!Function::parse("1 in [x,y,z]")->root().is_const());
+ EXPECT_TRUE(!Function::parse("1 in [1,y,z]")->root().is_const());
+ EXPECT_TRUE(!Function::parse("1 in [1,2,3]")->root().is_const());
}
TEST("require that calls are cost if all parameters are const") {
- EXPECT_TRUE(!Function::parse("max(x,y)").root().is_const());
- EXPECT_TRUE(!Function::parse("max(1,y)").root().is_const());
- EXPECT_TRUE(!Function::parse("max(x,2)").root().is_const());
- EXPECT_TRUE(Function::parse("max(1,2)").root().is_const());
+ EXPECT_TRUE(!Function::parse("max(x,y)")->root().is_const());
+ EXPECT_TRUE(!Function::parse("max(1,y)")->root().is_const());
+ EXPECT_TRUE(!Function::parse("max(x,2)")->root().is_const());
+ EXPECT_TRUE(Function::parse("max(1,2)")->root().is_const());
}
//-----------------------------------------------------------------------------
TEST("require that feature less than constant is tree if children are trees or constants") {
- EXPECT_TRUE(Function::parse("if (foo < 2, 3, 4)").root().is_tree());
- EXPECT_TRUE(Function::parse("if (foo < 2, if(bar < 3, 4, 5), 6)").root().is_tree());
- EXPECT_TRUE(Function::parse("if (foo < 2, if(bar < 3, 4, 5), if(baz < 6, 7, 8))").root().is_tree());
- EXPECT_TRUE(Function::parse("if (foo < 2, 3, if(baz < 4, 5, 6))").root().is_tree());
- EXPECT_TRUE(Function::parse("if (foo < max(1,2), 3, 4)").root().is_tree());
- EXPECT_TRUE(!Function::parse("if (2 < foo, 3, 4)").root().is_tree());
- EXPECT_TRUE(!Function::parse("if (foo < bar, 3, 4)").root().is_tree());
- EXPECT_TRUE(!Function::parse("if (1 < 2, 3, 4)").root().is_tree());
- EXPECT_TRUE(!Function::parse("if (foo <= 2, 3, 4)").root().is_tree());
- EXPECT_TRUE(!Function::parse("if (foo == 2, 3, 4)").root().is_tree());
- EXPECT_TRUE(!Function::parse("if (foo > 2, 3, 4)").root().is_tree());
- EXPECT_TRUE(!Function::parse("if (foo >= 2, 3, 4)").root().is_tree());
- EXPECT_TRUE(!Function::parse("if (foo ~= 2, 3, 4)").root().is_tree());
+ EXPECT_TRUE(Function::parse("if (foo < 2, 3, 4)")->root().is_tree());
+ EXPECT_TRUE(Function::parse("if (foo < 2, if(bar < 3, 4, 5), 6)")->root().is_tree());
+ EXPECT_TRUE(Function::parse("if (foo < 2, if(bar < 3, 4, 5), if(baz < 6, 7, 8))")->root().is_tree());
+ EXPECT_TRUE(Function::parse("if (foo < 2, 3, if(baz < 4, 5, 6))")->root().is_tree());
+ EXPECT_TRUE(Function::parse("if (foo < max(1,2), 3, 4)")->root().is_tree());
+ EXPECT_TRUE(!Function::parse("if (2 < foo, 3, 4)")->root().is_tree());
+ EXPECT_TRUE(!Function::parse("if (foo < bar, 3, 4)")->root().is_tree());
+ EXPECT_TRUE(!Function::parse("if (1 < 2, 3, 4)")->root().is_tree());
+ EXPECT_TRUE(!Function::parse("if (foo <= 2, 3, 4)")->root().is_tree());
+ EXPECT_TRUE(!Function::parse("if (foo == 2, 3, 4)")->root().is_tree());
+ EXPECT_TRUE(!Function::parse("if (foo > 2, 3, 4)")->root().is_tree());
+ EXPECT_TRUE(!Function::parse("if (foo >= 2, 3, 4)")->root().is_tree());
+ EXPECT_TRUE(!Function::parse("if (foo ~= 2, 3, 4)")->root().is_tree());
}
TEST("require that feature in set of constants is tree if children are trees or constants") {
- EXPECT_TRUE(Function::parse("if (foo in [1, 2], 3, 4)").root().is_tree());
- EXPECT_TRUE(Function::parse("if (foo in [1, 2], if(bar < 3, 4, 5), 6)").root().is_tree());
- EXPECT_TRUE(Function::parse("if (foo in [1, 2], if(bar < 3, 4, 5), if(baz < 6, 7, 8))").root().is_tree());
- EXPECT_TRUE(Function::parse("if (foo in [1, 2], 3, if(baz < 4, 5, 6))").root().is_tree());
- EXPECT_TRUE(Function::parse("if (foo in [1, 2], min(1,3), max(1,4))").root().is_tree());
- EXPECT_TRUE(!Function::parse("if (1 in [1, 2], 3, 4)").root().is_tree());
+ EXPECT_TRUE(Function::parse("if (foo in [1, 2], 3, 4)")->root().is_tree());
+ EXPECT_TRUE(Function::parse("if (foo in [1, 2], if(bar < 3, 4, 5), 6)")->root().is_tree());
+ EXPECT_TRUE(Function::parse("if (foo in [1, 2], if(bar < 3, 4, 5), if(baz < 6, 7, 8))")->root().is_tree());
+ EXPECT_TRUE(Function::parse("if (foo in [1, 2], 3, if(baz < 4, 5, 6))")->root().is_tree());
+ EXPECT_TRUE(Function::parse("if (foo in [1, 2], min(1,3), max(1,4))")->root().is_tree());
+ EXPECT_TRUE(!Function::parse("if (1 in [1, 2], 3, 4)")->root().is_tree());
}
TEST("require that sums of trees and forests are forests") {
- EXPECT_TRUE(Function::parse("if(foo<1,2,3) + if(bar<4,5,6)").root().is_forest());
- EXPECT_TRUE(Function::parse("if(foo<1,2,3) + if(bar<4,5,6) + if(bar<7,8,9)").root().is_forest());
- EXPECT_TRUE(!Function::parse("if(foo<1,2,3)").root().is_forest());
- EXPECT_TRUE(!Function::parse("if(foo<1,2,3) + 10").root().is_forest());
- EXPECT_TRUE(!Function::parse("10 + if(bar<4,5,6)").root().is_forest());
- EXPECT_TRUE(!Function::parse("if(foo<1,2,3) - if(bar<4,5,6)").root().is_forest());
- EXPECT_TRUE(!Function::parse("if(foo<1,2,3) * if(bar<4,5,6)").root().is_forest());
- EXPECT_TRUE(!Function::parse("if(foo<1,2,3) / if(bar<4,5,6)").root().is_forest());
- EXPECT_TRUE(!Function::parse("if(foo<1,2,3) ^ if(bar<4,5,6)").root().is_forest());
- EXPECT_TRUE(!Function::parse("if(foo<1,2,3) - if(bar<4,5,6) + if(bar<7,8,9)").root().is_forest());
- EXPECT_TRUE(!Function::parse("if(foo<1,2,3) * if(bar<4,5,6) + if(bar<7,8,9)").root().is_forest());
- EXPECT_TRUE(!Function::parse("if(foo<1,2,3) / if(bar<4,5,6) + if(bar<7,8,9)").root().is_forest());
- EXPECT_TRUE(!Function::parse("if(foo<1,2,3) ^ if(bar<4,5,6) + if(bar<7,8,9)").root().is_forest());
- EXPECT_TRUE(!Function::parse("if(foo<1,2,3) + if(bar<4,5,6) - if(bar<7,8,9)").root().is_forest());
- EXPECT_TRUE(!Function::parse("if(foo<1,2,3) + if(bar<4,5,6) * if(bar<7,8,9)").root().is_forest());
- EXPECT_TRUE(!Function::parse("if(foo<1,2,3) + if(bar<4,5,6) / if(bar<7,8,9)").root().is_forest());
- EXPECT_TRUE(!Function::parse("if(foo<1,2,3) + if(bar<4,5,6) ^ if(bar<7,8,9)").root().is_forest());
+ EXPECT_TRUE(Function::parse("if(foo<1,2,3) + if(bar<4,5,6)")->root().is_forest());
+ EXPECT_TRUE(Function::parse("if(foo<1,2,3) + if(bar<4,5,6) + if(bar<7,8,9)")->root().is_forest());
+ EXPECT_TRUE(!Function::parse("if(foo<1,2,3)")->root().is_forest());
+ EXPECT_TRUE(!Function::parse("if(foo<1,2,3) + 10")->root().is_forest());
+ EXPECT_TRUE(!Function::parse("10 + if(bar<4,5,6)")->root().is_forest());
+ EXPECT_TRUE(!Function::parse("if(foo<1,2,3) - if(bar<4,5,6)")->root().is_forest());
+ EXPECT_TRUE(!Function::parse("if(foo<1,2,3) * if(bar<4,5,6)")->root().is_forest());
+ EXPECT_TRUE(!Function::parse("if(foo<1,2,3) / if(bar<4,5,6)")->root().is_forest());
+ EXPECT_TRUE(!Function::parse("if(foo<1,2,3) ^ if(bar<4,5,6)")->root().is_forest());
+ EXPECT_TRUE(!Function::parse("if(foo<1,2,3) - if(bar<4,5,6) + if(bar<7,8,9)")->root().is_forest());
+ EXPECT_TRUE(!Function::parse("if(foo<1,2,3) * if(bar<4,5,6) + if(bar<7,8,9)")->root().is_forest());
+ EXPECT_TRUE(!Function::parse("if(foo<1,2,3) / if(bar<4,5,6) + if(bar<7,8,9)")->root().is_forest());
+ EXPECT_TRUE(!Function::parse("if(foo<1,2,3) ^ if(bar<4,5,6) + if(bar<7,8,9)")->root().is_forest());
+ EXPECT_TRUE(!Function::parse("if(foo<1,2,3) + if(bar<4,5,6) - if(bar<7,8,9)")->root().is_forest());
+ EXPECT_TRUE(!Function::parse("if(foo<1,2,3) + if(bar<4,5,6) * if(bar<7,8,9)")->root().is_forest());
+ EXPECT_TRUE(!Function::parse("if(foo<1,2,3) + if(bar<4,5,6) / if(bar<7,8,9)")->root().is_forest());
+ EXPECT_TRUE(!Function::parse("if(foo<1,2,3) + if(bar<4,5,6) ^ if(bar<7,8,9)")->root().is_forest());
}
//-----------------------------------------------------------------------------
@@ -657,53 +657,53 @@ struct MySymbolExtractor : SymbolExtractor {
};
TEST("require that custom symbol extractor may be used") {
- EXPECT_EQUAL("[x+]...[missing value]...[*y]", Function::parse(params, "x+*y").dump());
- EXPECT_EQUAL("[x+]...[missing value]...[*y]", Function::parse(params, "x+*y", MySymbolExtractor()).dump());
- EXPECT_EQUAL("[x+]...[unknown symbol: 'x+']...[*y]", Function::parse(params, "x+*y", MySymbolExtractor({'+'})).dump());
- EXPECT_EQUAL("[x+*y]...[unknown symbol: 'x+*y']...[]", Function::parse(params, "x+*y", MySymbolExtractor({'+', '*'})).dump());
+ EXPECT_EQUAL("[x+]...[missing value]...[*y]", Function::parse(params, "x+*y")->dump());
+ EXPECT_EQUAL("[x+]...[missing value]...[*y]", Function::parse(params, "x+*y", MySymbolExtractor())->dump());
+ EXPECT_EQUAL("[x+]...[unknown symbol: 'x+']...[*y]", Function::parse(params, "x+*y", MySymbolExtractor({'+'}))->dump());
+ EXPECT_EQUAL("[x+*y]...[unknown symbol: 'x+*y']...[]", Function::parse(params, "x+*y", MySymbolExtractor({'+', '*'}))->dump());
}
TEST("require that unknown function works as expected with custom symbol extractor") {
- EXPECT_EQUAL("[bogus(]...[unknown function: 'bogus']...[x)+y]", Function::parse(params, "bogus(x)+y").dump());
- EXPECT_EQUAL("[bogus]...[unknown symbol: 'bogus']...[(x)+y]", Function::parse(params, "bogus(x)+y", MySymbolExtractor()).dump());
- EXPECT_EQUAL("[bogus(x)]...[unknown symbol: 'bogus(x)']...[+y]", Function::parse(params, "bogus(x)+y", MySymbolExtractor({'(', ')'})).dump());
+ EXPECT_EQUAL("[bogus(]...[unknown function: 'bogus']...[x)+y]", Function::parse(params, "bogus(x)+y")->dump());
+ EXPECT_EQUAL("[bogus]...[unknown symbol: 'bogus']...[(x)+y]", Function::parse(params, "bogus(x)+y", MySymbolExtractor())->dump());
+ EXPECT_EQUAL("[bogus(x)]...[unknown symbol: 'bogus(x)']...[+y]", Function::parse(params, "bogus(x)+y", MySymbolExtractor({'(', ')'}))->dump());
}
TEST("require that unknown function that is valid parameter works as expected with custom symbol extractor") {
- EXPECT_EQUAL("[z(]...[unknown function: 'z']...[x)+y]", Function::parse(params, "z(x)+y").dump());
- EXPECT_EQUAL("[z]...[invalid operator: '(']...[(x)+y]", Function::parse(params, "z(x)+y", MySymbolExtractor()).dump());
- EXPECT_EQUAL("[z(x)]...[unknown symbol: 'z(x)']...[+y]", Function::parse(params, "z(x)+y", MySymbolExtractor({'(', ')'})).dump());
+ EXPECT_EQUAL("[z(]...[unknown function: 'z']...[x)+y]", Function::parse(params, "z(x)+y")->dump());
+ EXPECT_EQUAL("[z]...[invalid operator: '(']...[(x)+y]", Function::parse(params, "z(x)+y", MySymbolExtractor())->dump());
+ EXPECT_EQUAL("[z(x)]...[unknown symbol: 'z(x)']...[+y]", Function::parse(params, "z(x)+y", MySymbolExtractor({'(', ')'}))->dump());
}
TEST("require that custom symbol extractor is not invoked for known function call") {
MySymbolExtractor extractor;
EXPECT_EQUAL(extractor.invoke_count, 0u);
- EXPECT_EQUAL("[bogus]...[unknown symbol: 'bogus']...[(1,2)]", Function::parse(params, "bogus(1,2)", extractor).dump());
+ EXPECT_EQUAL("[bogus]...[unknown symbol: 'bogus']...[(1,2)]", Function::parse(params, "bogus(1,2)", extractor)->dump());
EXPECT_EQUAL(extractor.invoke_count, 1u);
- EXPECT_EQUAL("max(1,2)", Function::parse(params, "max(1,2)", extractor).dump());
+ EXPECT_EQUAL("max(1,2)", Function::parse(params, "max(1,2)", extractor)->dump());
EXPECT_EQUAL(extractor.invoke_count, 1u);
}
//-----------------------------------------------------------------------------
TEST("require that valid function does not report parse error") {
- Function function = Function::parse(params, "x + y");
- EXPECT_TRUE(!function.has_error());
- EXPECT_EQUAL("", function.get_error());
+ auto function = Function::parse(params, "x + y");
+ EXPECT_TRUE(!function->has_error());
+ EXPECT_EQUAL("", function->get_error());
}
TEST("require that an invalid function with explicit paramers retain its parameters") {
- Function function = Function::parse({"x", "y"}, "x & y");
- EXPECT_TRUE(function.has_error());
- ASSERT_EQUAL(2u, function.num_params());
- ASSERT_EQUAL("x", function.param_name(0));
- ASSERT_EQUAL("y", function.param_name(1));
+ auto function = Function::parse({"x", "y"}, "x & y");
+ EXPECT_TRUE(function->has_error());
+ ASSERT_EQUAL(2u, function->num_params());
+ ASSERT_EQUAL("x", function->param_name(0));
+ ASSERT_EQUAL("y", function->param_name(1));
}
TEST("require that an invalid function with implicit paramers has no parameters") {
- Function function = Function::parse("x & y");
- EXPECT_TRUE(function.has_error());
- EXPECT_EQUAL(0u, function.num_params());
+ auto function = Function::parse("x & y");
+ EXPECT_TRUE(function->has_error());
+ EXPECT_EQUAL(0u, function->num_params());
}
TEST("require that unknown operator gives parse error") {
@@ -725,23 +725,23 @@ TEST("require that missing value gives parse error") {
TEST("require that tensor operations can be nested") {
EXPECT_EQUAL("reduce(reduce(reduce(a,sum),sum),sum,dim)",
- Function::parse("reduce(reduce(reduce(a,sum),sum),sum,dim)").dump());
+ Function::parse("reduce(reduce(reduce(a,sum),sum),sum,dim)")->dump());
}
//-----------------------------------------------------------------------------
TEST("require that tensor map can be parsed") {
- EXPECT_EQUAL("map(a,f(x)(x+1))", Function::parse("map(a,f(x)(x+1))").dump());
- EXPECT_EQUAL("map(a,f(x)(x+1))", Function::parse(" map ( a , f ( x ) ( x + 1 ) ) ").dump());
+ EXPECT_EQUAL("map(a,f(x)(x+1))", Function::parse("map(a,f(x)(x+1))")->dump());
+ EXPECT_EQUAL("map(a,f(x)(x+1))", Function::parse(" map ( a , f ( x ) ( x + 1 ) ) ")->dump());
}
TEST("require that tensor join can be parsed") {
- EXPECT_EQUAL("join(a,b,f(x,y)(x+y))", Function::parse("join(a,b,f(x,y)(x+y))").dump());
- EXPECT_EQUAL("join(a,b,f(x,y)(x+y))", Function::parse(" join ( a , b , f ( x , y ) ( x + y ) ) ").dump());
+ EXPECT_EQUAL("join(a,b,f(x,y)(x+y))", Function::parse("join(a,b,f(x,y)(x+y))")->dump());
+ EXPECT_EQUAL("join(a,b,f(x,y)(x+y))", Function::parse(" join ( a , b , f ( x , y ) ( x + y ) ) ")->dump());
}
TEST("require that parenthesis are added around lambda expression when needed") {
- EXPECT_EQUAL("f(x)(sin(x))", Function::parse("sin(x)").dump_as_lambda());
+ EXPECT_EQUAL("f(x)(sin(x))", Function::parse("sin(x)")->dump_as_lambda());
}
TEST("require that parse error inside a lambda fails the enclosing expression") {
@@ -755,16 +755,16 @@ TEST("require that outer parameters are hidden within a lambda") {
//-----------------------------------------------------------------------------
TEST("require that tensor reduce can be parsed") {
- EXPECT_EQUAL("reduce(x,sum,a,b)", Function::parse({"x"}, "reduce(x,sum,a,b)").dump());
- EXPECT_EQUAL("reduce(x,sum,a,b,c)", Function::parse({"x"}, "reduce(x,sum,a,b,c)").dump());
- EXPECT_EQUAL("reduce(x,sum,a,b,c)", Function::parse({"x"}, " reduce ( x , sum , a , b , c ) ").dump());
- EXPECT_EQUAL("reduce(x,sum)", Function::parse({"x"}, "reduce(x,sum)").dump());
- EXPECT_EQUAL("reduce(x,avg)", Function::parse({"x"}, "reduce(x,avg)").dump());
- EXPECT_EQUAL("reduce(x,avg)", Function::parse({"x"}, "reduce( x , avg )").dump());
- EXPECT_EQUAL("reduce(x,count)", Function::parse({"x"}, "reduce(x,count)").dump());
- EXPECT_EQUAL("reduce(x,prod)", Function::parse({"x"}, "reduce(x,prod)").dump());
- EXPECT_EQUAL("reduce(x,min)", Function::parse({"x"}, "reduce(x,min)").dump());
- EXPECT_EQUAL("reduce(x,max)", Function::parse({"x"}, "reduce(x,max)").dump());
+ EXPECT_EQUAL("reduce(x,sum,a,b)", Function::parse({"x"}, "reduce(x,sum,a,b)")->dump());
+ EXPECT_EQUAL("reduce(x,sum,a,b,c)", Function::parse({"x"}, "reduce(x,sum,a,b,c)")->dump());
+ EXPECT_EQUAL("reduce(x,sum,a,b,c)", Function::parse({"x"}, " reduce ( x , sum , a , b , c ) ")->dump());
+ EXPECT_EQUAL("reduce(x,sum)", Function::parse({"x"}, "reduce(x,sum)")->dump());
+ EXPECT_EQUAL("reduce(x,avg)", Function::parse({"x"}, "reduce(x,avg)")->dump());
+ EXPECT_EQUAL("reduce(x,avg)", Function::parse({"x"}, "reduce( x , avg )")->dump());
+ EXPECT_EQUAL("reduce(x,count)", Function::parse({"x"}, "reduce(x,count)")->dump());
+ EXPECT_EQUAL("reduce(x,prod)", Function::parse({"x"}, "reduce(x,prod)")->dump());
+ EXPECT_EQUAL("reduce(x,min)", Function::parse({"x"}, "reduce(x,min)")->dump());
+ EXPECT_EQUAL("reduce(x,max)", Function::parse({"x"}, "reduce(x,max)")->dump());
}
TEST("require that tensor reduce with unknown aggregator fails") {
@@ -778,14 +778,14 @@ TEST("require that tensor reduce with duplicate dimensions fails") {
//-----------------------------------------------------------------------------
TEST("require that tensor rename can be parsed") {
- EXPECT_EQUAL("rename(x,a,b)", Function::parse({"x"}, "rename(x,a,b)").dump());
- EXPECT_EQUAL("rename(x,a,b)", Function::parse({"x"}, "rename(x,(a),(b))").dump());
- EXPECT_EQUAL("rename(x,a,b)", Function::parse({"x"}, "rename(x,a,(b))").dump());
- EXPECT_EQUAL("rename(x,a,b)", Function::parse({"x"}, "rename(x,(a),b)").dump());
- EXPECT_EQUAL("rename(x,(a,b),(b,a))", Function::parse({"x"}, "rename(x,(a,b),(b,a))").dump());
- EXPECT_EQUAL("rename(x,a,b)", Function::parse({"x"}, "rename( x , a , b )").dump());
- EXPECT_EQUAL("rename(x,a,b)", Function::parse({"x"}, "rename( x , ( a ) , ( b ) )").dump());
- EXPECT_EQUAL("rename(x,(a,b),(b,a))", Function::parse({"x"}, "rename( x , ( a , b ) , ( b , a ) )").dump());
+ EXPECT_EQUAL("rename(x,a,b)", Function::parse({"x"}, "rename(x,a,b)")->dump());
+ EXPECT_EQUAL("rename(x,a,b)", Function::parse({"x"}, "rename(x,(a),(b))")->dump());
+ EXPECT_EQUAL("rename(x,a,b)", Function::parse({"x"}, "rename(x,a,(b))")->dump());
+ EXPECT_EQUAL("rename(x,a,b)", Function::parse({"x"}, "rename(x,(a),b)")->dump());
+ EXPECT_EQUAL("rename(x,(a,b),(b,a))", Function::parse({"x"}, "rename(x,(a,b),(b,a))")->dump());
+ EXPECT_EQUAL("rename(x,a,b)", Function::parse({"x"}, "rename( x , a , b )")->dump());
+ EXPECT_EQUAL("rename(x,a,b)", Function::parse({"x"}, "rename( x , ( a ) , ( b ) )")->dump());
+ EXPECT_EQUAL("rename(x,(a,b),(b,a))", Function::parse({"x"}, "rename( x , ( a , b ) , ( b , a ) )")->dump());
}
TEST("require that tensor rename dimension lists cannot be empty") {
@@ -808,9 +808,9 @@ TEST("require that tensor rename dimension lists must have equal size") {
//-----------------------------------------------------------------------------
TEST("require that tensor lambda can be parsed") {
- EXPECT_EQUAL("tensor(x[3]):{{x:0}:0,{x:1}:1,{x:2}:2}", Function::parse({}, "tensor(x[3])(x)").dump());
+ EXPECT_EQUAL("tensor(x[3]):{{x:0}:0,{x:1}:1,{x:2}:2}", Function::parse({}, "tensor(x[3])(x)")->dump());
EXPECT_EQUAL("tensor(x[2],y[2]):{{x:0,y:0}:(0==0),{x:0,y:1}:(0==1),{x:1,y:0}:(1==0),{x:1,y:1}:(1==1)}",
- Function::parse({}, " tensor ( x [ 2 ] , y [ 2 ] ) ( x == y ) ").dump());
+ Function::parse({}, " tensor ( x [ 2 ] , y [ 2 ] ) ( x == y ) ")->dump());
}
TEST("require that tensor lambda requires appropriate tensor type") {
@@ -821,7 +821,7 @@ TEST("require that tensor lambda requires appropriate tensor type") {
TEST("require that tensor lambda can use non-dimension symbols") {
EXPECT_EQUAL("tensor(x[2]):{{x:0}:(0==a),{x:1}:(1==a)}",
- Function::parse({"a"}, "tensor(x[2])(x==a)").dump());
+ Function::parse({"a"}, "tensor(x[2])(x==a)")->dump());
}
//-----------------------------------------------------------------------------
@@ -832,24 +832,24 @@ TEST("require that verbose tensor create can be parsed") {
auto sparse2 = Function::parse("tensor(x{}):{{x:\"a\"}:1,{x:\"b\"}:2,{x:\"c\"}:3}");
auto mixed1 = Function::parse("tensor(x{},y[2]):{{x:a,y:0}:1,{x:a,y:1}:2}");
auto mixed2 = Function::parse("tensor(x{},y[2]):{{x:\"a\",y:0}:1,{x:\"a\",y:1}:2}");
- EXPECT_EQUAL("tensor(x[3]):{{x:0}:1,{x:1}:2,{x:2}:3}", dense.dump());
- EXPECT_EQUAL("tensor(x{}):{{x:\"a\"}:1,{x:\"b\"}:2,{x:\"c\"}:3}", sparse1.dump());
- EXPECT_EQUAL("tensor(x{}):{{x:\"a\"}:1,{x:\"b\"}:2,{x:\"c\"}:3}", sparse2.dump());
- EXPECT_EQUAL("tensor(x{},y[2]):{{x:\"a\",y:0}:1,{x:\"a\",y:1}:2}", mixed1.dump());
- EXPECT_EQUAL("tensor(x{},y[2]):{{x:\"a\",y:0}:1,{x:\"a\",y:1}:2}", mixed2.dump());
+ EXPECT_EQUAL("tensor(x[3]):{{x:0}:1,{x:1}:2,{x:2}:3}", dense->dump());
+ EXPECT_EQUAL("tensor(x{}):{{x:\"a\"}:1,{x:\"b\"}:2,{x:\"c\"}:3}", sparse1->dump());
+ EXPECT_EQUAL("tensor(x{}):{{x:\"a\"}:1,{x:\"b\"}:2,{x:\"c\"}:3}", sparse2->dump());
+ EXPECT_EQUAL("tensor(x{},y[2]):{{x:\"a\",y:0}:1,{x:\"a\",y:1}:2}", mixed1->dump());
+ EXPECT_EQUAL("tensor(x{},y[2]):{{x:\"a\",y:0}:1,{x:\"a\",y:1}:2}", mixed2->dump());
}
TEST("require that verbose tensor create can contain expressions") {
auto fun = Function::parse("tensor(x[2]):{{x:0}:1,{x:1}:2+a}");
- EXPECT_EQUAL("tensor(x[2]):{{x:0}:1,{x:1}:(2+a)}", fun.dump());
- ASSERT_EQUAL(fun.num_params(), 1u);
- EXPECT_EQUAL(fun.param_name(0), "a");
+ EXPECT_EQUAL("tensor(x[2]):{{x:0}:1,{x:1}:(2+a)}", fun->dump());
+ ASSERT_EQUAL(fun->num_params(), 1u);
+ EXPECT_EQUAL(fun->param_name(0), "a");
}
TEST("require that verbose tensor create handles spaces and reordering of various elements") {
auto fun = Function::parse(" tensor ( y [ 2 ] , x [ 2 ] ) : { { x : 0 , y : 1 } : 2 , "
"{ y : 0 , x : 0 } : 1 , { y : 0 , x : 1 } : 3 , { x : 1 , y : 1 } : 4 } ");
- EXPECT_EQUAL("tensor(x[2],y[2]):{{x:0,y:0}:1,{x:0,y:1}:2,{x:1,y:0}:3,{x:1,y:1}:4}", fun.dump());
+ EXPECT_EQUAL("tensor(x[2],y[2]):{{x:0,y:0}:1,{x:0,y:1}:2,{x:1,y:0}:3,{x:1,y:1}:4}", fun->dump());
}
TEST("require that verbose tensor create detects invalid tensor type") {
@@ -890,23 +890,23 @@ TEST("require that convenient tensor create can be parsed") {
auto sparse2 = Function::parse("tensor(x{}):{\"a\":1,\"b\":2,\"c\":3}");
auto mixed1 = Function::parse("tensor(x{},y[2]):{a:[1,2]}");
auto mixed2 = Function::parse("tensor(x{},y[2]):{\"a\":[1,2]}");
- EXPECT_EQUAL("tensor(x[3]):{{x:0}:1,{x:1}:2,{x:2}:3}", dense.dump());
- EXPECT_EQUAL("tensor(x{}):{{x:\"a\"}:1,{x:\"b\"}:2,{x:\"c\"}:3}", sparse1.dump());
- EXPECT_EQUAL("tensor(x{}):{{x:\"a\"}:1,{x:\"b\"}:2,{x:\"c\"}:3}", sparse2.dump());
- EXPECT_EQUAL("tensor(x{},y[2]):{{x:\"a\",y:0}:1,{x:\"a\",y:1}:2}", mixed1.dump());
- EXPECT_EQUAL("tensor(x{},y[2]):{{x:\"a\",y:0}:1,{x:\"a\",y:1}:2}", mixed2.dump());
+ EXPECT_EQUAL("tensor(x[3]):{{x:0}:1,{x:1}:2,{x:2}:3}", dense->dump());
+ EXPECT_EQUAL("tensor(x{}):{{x:\"a\"}:1,{x:\"b\"}:2,{x:\"c\"}:3}", sparse1->dump());
+ EXPECT_EQUAL("tensor(x{}):{{x:\"a\"}:1,{x:\"b\"}:2,{x:\"c\"}:3}", sparse2->dump());
+ EXPECT_EQUAL("tensor(x{},y[2]):{{x:\"a\",y:0}:1,{x:\"a\",y:1}:2}", mixed1->dump());
+ EXPECT_EQUAL("tensor(x{},y[2]):{{x:\"a\",y:0}:1,{x:\"a\",y:1}:2}", mixed2->dump());
}
TEST("require that convenient tensor create can contain expressions") {
auto fun = Function::parse("tensor(x[2]):[1,2+a]");
- EXPECT_EQUAL("tensor(x[2]):{{x:0}:1,{x:1}:(2+a)}", fun.dump());
- ASSERT_EQUAL(fun.num_params(), 1u);
- EXPECT_EQUAL(fun.param_name(0), "a");
+ EXPECT_EQUAL("tensor(x[2]):{{x:0}:1,{x:1}:(2+a)}", fun->dump());
+ ASSERT_EQUAL(fun->num_params(), 1u);
+ EXPECT_EQUAL(fun->param_name(0), "a");
}
TEST("require that convenient tensor create handles dimension order") {
auto mixed = Function::parse("tensor(y{},x[2]):{a:[1,2]}");
- EXPECT_EQUAL("tensor(x[2],y{}):{{x:0,y:\"a\"}:1,{x:1,y:\"a\"}:2}", mixed.dump());
+ EXPECT_EQUAL("tensor(x[2],y{}):{{x:0,y:\"a\"}:1,{x:1,y:\"a\"}:2}", mixed->dump());
}
TEST("require that convenient tensor create can be highly nested") {
@@ -914,9 +914,9 @@ TEST("require that convenient tensor create can be highly nested") {
auto nested1 = Function::parse("tensor(a{},b{},c[1],d[1]):{x:{y:[[5]]}}");
auto nested2 = Function::parse("tensor(c[1],d[1],a{},b{}):[[{x:{y:5}}]]");
auto nested3 = Function::parse("tensor(a{},c[1],b{},d[1]): { x : [ { y : [ 5 ] } ] } ");
- EXPECT_EQUAL(expect, nested1.dump());
- EXPECT_EQUAL(expect, nested2.dump());
- EXPECT_EQUAL(expect, nested3.dump());
+ EXPECT_EQUAL(expect, nested1->dump());
+ EXPECT_EQUAL(expect, nested2->dump());
+ EXPECT_EQUAL(expect, nested3->dump());
}
TEST("require that convenient tensor create can have multiple values on multiple levels") {
@@ -925,15 +925,15 @@ TEST("require that convenient tensor create can have multiple values on multiple
auto fun2 = Function::parse("tensor(y[2],x{}):[{a:1,b:3},{a:2,b:4}]");
auto fun3 = Function::parse("tensor(x{},y[2]): { a : [ 1 , 2 ] , b : [ 3 , 4 ] } ");
auto fun4 = Function::parse("tensor(y[2],x{}): [ { a : 1 , b : 3 } , { a : 2 , b : 4 } ] ");
- EXPECT_EQUAL(expect, fun1.dump());
- EXPECT_EQUAL(expect, fun2.dump());
- EXPECT_EQUAL(expect, fun3.dump());
- EXPECT_EQUAL(expect, fun4.dump());
+ EXPECT_EQUAL(expect, fun1->dump());
+ EXPECT_EQUAL(expect, fun2->dump());
+ EXPECT_EQUAL(expect, fun3->dump());
+ EXPECT_EQUAL(expect, fun4->dump());
}
TEST("require that convenient tensor create allows under-specified tensors") {
auto fun = Function::parse("tensor(x[2],y[2]):[[],[5]]");
- EXPECT_EQUAL("tensor(x[2],y[2]):{{x:1,y:0}:5}", fun.dump());
+ EXPECT_EQUAL("tensor(x[2],y[2]):{{x:1,y:0}:5}", fun->dump());
}
TEST("require that convenient tensor create detects invalid tensor type") {
@@ -992,16 +992,16 @@ TEST("require that tensor peek empty label is not allowed") {
TEST("require that nested tensor lambda using tensor peek can be parsed") {
vespalib::string expect("tensor(x[2]):{{x:0}:tensor(y[2]):{{y:0}:((0+0)+a),{y:1}:((0+1)+a)}{y:\"0\"},"
"{x:1}:tensor(y[2]):{{y:0}:((1+0)+a),{y:1}:((1+1)+a)}{y:\"1\"}}");
- EXPECT_EQUAL(Function::parse(expect).dump(), expect);
+ EXPECT_EQUAL(Function::parse(expect)->dump(), expect);
auto fun = Function::parse("tensor(x[2])(tensor(y[2])(x+y+a){y:(x)})");
- EXPECT_EQUAL(fun.dump(), expect);
+ EXPECT_EQUAL(fun->dump(), expect);
}
//-----------------------------------------------------------------------------
TEST("require that tensor concat can be parsed") {
- EXPECT_EQUAL("concat(a,b,d)", Function::parse({"a", "b"}, "concat(a,b,d)").dump());
- EXPECT_EQUAL("concat(a,b,d)", Function::parse({"a", "b"}, " concat ( a , b , d ) ").dump());
+ EXPECT_EQUAL("concat(a,b,d)", Function::parse({"a", "b"}, "concat(a,b,d)")->dump());
+ EXPECT_EQUAL("concat(a,b,d)", Function::parse({"a", "b"}, " concat ( a , b , d ) ")->dump());
}
//-----------------------------------------------------------------------------
@@ -1012,10 +1012,10 @@ struct CheckExpressions : test::EvalSpec::EvalTest {
virtual void next_expression(const std::vector<vespalib::string> &param_names,
const vespalib::string &expression) override
{
- Function function = Function::parse(param_names, expression);
- if (function.has_error()) {
+ auto function = Function::parse(param_names, expression);
+ if (function->has_error()) {
failed = true;
- fprintf(stderr, "parse error: %s\n", function.get_error().c_str());
+ fprintf(stderr, "parse error: %s\n", function->get_error().c_str());
}
++seen_cnt;
}
diff --git a/eval/src/tests/eval/function_speed/function_speed_test.cpp b/eval/src/tests/eval/function_speed/function_speed_test.cpp
index 178ab32d734..1295f482f76 100644
--- a/eval/src/tests/eval/function_speed/function_speed_test.cpp
+++ b/eval/src/tests/eval/function_speed/function_speed_test.cpp
@@ -39,7 +39,7 @@ double big_gcc_function(double p, double o, double q, double f, double w) {
//-----------------------------------------------------------------------------
struct Fixture {
- Function function;
+ std::shared_ptr<Function const> function;
InterpretedFunction interpreted_simple;
InterpretedFunction interpreted;
CompiledFunction separate;
@@ -47,12 +47,12 @@ struct Fixture {
CompiledFunction lazy;
Fixture(const vespalib::string &expr)
: function(Function::parse(expr)),
- interpreted_simple(SimpleTensorEngine::ref(), function, NodeTypes()),
- interpreted(DefaultTensorEngine::ref(), function,
- NodeTypes(function, std::vector<ValueType>(function.num_params(), ValueType::double_type()))),
- separate(function, PassParams::SEPARATE),
- array(function, PassParams::ARRAY),
- lazy(function, PassParams::LAZY) {}
+ interpreted_simple(SimpleTensorEngine::ref(), *function, NodeTypes()),
+ interpreted(DefaultTensorEngine::ref(), *function,
+ NodeTypes(*function, std::vector<ValueType>(function->num_params(), ValueType::double_type()))),
+ separate(*function, PassParams::SEPARATE),
+ array(*function, PassParams::ARRAY),
+ lazy(*function, PassParams::LAZY) {}
};
//-----------------------------------------------------------------------------
diff --git a/eval/src/tests/eval/gbdt/fast_forest_bench.cpp b/eval/src/tests/eval/gbdt/fast_forest_bench.cpp
index 76a56bec50c..f63fc428f64 100644
--- a/eval/src/tests/eval/gbdt/fast_forest_bench.cpp
+++ b/eval/src/tests/eval/gbdt/fast_forest_bench.cpp
@@ -31,17 +31,17 @@ void run_fast_forest_bench() {
for (size_t invert_percent: std::vector<size_t>({50})) {
fprintf(stderr, "\n=== features: %zu, num leafs: %zu, num trees: %zu\n", max_features, tree_size, num_trees);
vespalib::string expression = Model().max_features(max_features).less_percent(less_percent).invert_percent(invert_percent).make_forest(num_trees, tree_size);
- Function function = Function::parse(expression);
+ auto function = Function::parse(expression);
for (size_t min_bits = std::max(size_t(8), tree_size); true; min_bits *= 2) {
- auto forest = FastForest::try_convert(function, min_bits, 64);
+ auto forest = FastForest::try_convert(*function, min_bits, 64);
if (forest) {
- estimate_cost(function.num_params(), forest->impl_name().c_str(), *forest);
+ estimate_cost(function->num_params(), forest->impl_name().c_str(), *forest);
}
if (min_bits > 64) {
break;
}
}
- estimate_cost(function.num_params(), "vm forest", CompiledFunction(function, PassParams::ARRAY, VMForest::optimize_chain));
+ estimate_cost(function->num_params(), "vm forest", CompiledFunction(*function, PassParams::ARRAY, VMForest::optimize_chain));
}
}
}
diff --git a/eval/src/tests/eval/gbdt/gbdt_benchmark.cpp b/eval/src/tests/eval/gbdt/gbdt_benchmark.cpp
index 20e04c9593e..8c34d6f7d7f 100644
--- a/eval/src/tests/eval/gbdt/gbdt_benchmark.cpp
+++ b/eval/src/tests/eval/gbdt/gbdt_benchmark.cpp
@@ -163,10 +163,10 @@ std::vector<Option> find_order(const ForestParams &params,
size_t num_trees)
{
std::vector<Result> results;
- Function forest = make_forest(params, num_trees);
+ auto forest = make_forest(params, num_trees);
for (size_t i = 0; i < options.size(); ++i) {
- CompiledFunction compiled_function = options[i].compile(forest);
- CompiledFunction compiled_function_lazy = options[i].compile_lazy(forest);
+ CompiledFunction compiled_function = options[i].compile(*forest);
+ CompiledFunction compiled_function_lazy = options[i].compile_lazy(*forest);
std::vector<double> inputs(compiled_function.num_params(), 0.5);
results.push_back({compiled_function.estimate_cost_us(inputs, budget), i});
double lazy_time = compiled_function_lazy.estimate_cost_us(inputs, budget);
@@ -184,7 +184,7 @@ std::vector<Option> find_order(const ForestParams &params,
}
double expected_path(const ForestParams &params, size_t num_trees) {
- return ForestStats(extract_trees(make_forest(params, num_trees).root())).total_expected_path_length;
+ return ForestStats(extract_trees(make_forest(params, num_trees)->root())).total_expected_path_length;
}
void explore_segment(const ForestParams &params,
diff --git a/eval/src/tests/eval/gbdt/gbdt_test.cpp b/eval/src/tests/eval/gbdt/gbdt_test.cpp
index adb3d22847a..e12ea5b9cf3 100644
--- a/eval/src/tests/eval/gbdt/gbdt_test.cpp
+++ b/eval/src/tests/eval/gbdt/gbdt_test.cpp
@@ -54,10 +54,10 @@ double eval_ff(const FastForest &ff, FastForest::Context &ctx, const std::vector
TEST("require that tree stats can be calculated") {
for (size_t tree_size = 2; tree_size < 64; ++tree_size) {
- EXPECT_EQUAL(tree_size, TreeStats(Function::parse(Model().make_tree(tree_size)).root()).size);
+ EXPECT_EQUAL(tree_size, TreeStats(Function::parse(Model().make_tree(tree_size))->root()).size);
}
- TreeStats stats1(Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))").root());
+ TreeStats stats1(Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))")->root());
EXPECT_EQUAL(3u, stats1.num_params);
EXPECT_EQUAL(4u, stats1.size);
EXPECT_EQUAL(1u, stats1.num_less_checks);
@@ -65,7 +65,7 @@ TEST("require that tree stats can be calculated") {
EXPECT_EQUAL(0u, stats1.num_inverted_checks);
EXPECT_EQUAL(3u, stats1.max_set_size);
- TreeStats stats2(Function::parse("if((d in [1]),10.0,if(!(e>=1),20.0,30.0))").root());
+ TreeStats stats2(Function::parse("if((d in [1]),10.0,if(!(e>=1),20.0,30.0))")->root());
EXPECT_EQUAL(2u, stats2.num_params);
EXPECT_EQUAL(3u, stats2.size);
EXPECT_EQUAL(0u, stats2.num_less_checks);
@@ -78,8 +78,8 @@ TEST("require that trees can be extracted from forest") {
for (size_t tree_size = 10; tree_size < 20; ++tree_size) {
for (size_t forest_size = 10; forest_size < 20; ++forest_size) {
vespalib::string expression = Model().make_forest(forest_size, tree_size);
- Function function = Function::parse(expression);
- std::vector<const Node *> trees = extract_trees(function.root());
+ auto function = Function::parse(expression);
+ std::vector<const Node *> trees = extract_trees(function->root());
EXPECT_EQUAL(forest_size, trees.size());
for (const Node *tree: trees) {
EXPECT_EQUAL(tree_size, TreeStats(*tree).size);
@@ -89,10 +89,10 @@ TEST("require that trees can be extracted from forest") {
}
TEST("require that forest stats can be calculated") {
- Function function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))+"
- "if((d in [1]),10.0,if(!(e>=1),20.0,30.0))+"
- "if((a<1),10.0,if(!(e>=1),20.0,30.0))");
- std::vector<const Node *> trees = extract_trees(function.root());
+ auto function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))+"
+ "if((d in [1]),10.0,if(!(e>=1),20.0,30.0))+"
+ "if((a<1),10.0,if(!(e>=1),20.0,30.0))");
+ std::vector<const Node *> trees = extract_trees(function->root());
ForestStats stats(trees);
EXPECT_EQUAL(5u, stats.num_params);
EXPECT_EQUAL(3u, stats.num_trees);
@@ -109,7 +109,7 @@ TEST("require that forest stats can be calculated") {
}
double expected_path(const vespalib::string &forest) {
- return ForestStats(extract_trees(Function::parse(forest).root())).total_expected_path_length;
+ return ForestStats(extract_trees(Function::parse(forest)->root())).total_expected_path_length;
}
TEST("require that expected path length is calculated correctly") {
@@ -125,7 +125,7 @@ TEST("require that expected path length is calculated correctly") {
}
double average_path(const vespalib::string &forest) {
- return ForestStats(extract_trees(Function::parse(forest).root())).total_average_path_length;
+ return ForestStats(extract_trees(Function::parse(forest)->root())).total_average_path_length;
}
TEST("require that average path length is calculated correctly") {
@@ -141,7 +141,7 @@ TEST("require that average path length is calculated correctly") {
}
double count_tuned(const vespalib::string &forest) {
- return ForestStats(extract_trees(Function::parse(forest).root())).total_tuned_checks;
+ return ForestStats(extract_trees(Function::parse(forest)->root())).total_tuned_checks;
}
TEST("require that tuned checks are counted correctly") {
@@ -204,11 +204,11 @@ struct DummyForest2 : public Forest {
TEST("require that trees cannot be optimized by a forest optimizer when using SEPARATE params") {
Optimize::Chain chain({DummyForest0::optimize});
- Function function = Function::parse("if((a<1),1.0,if((b<1),if((c<1),2.0,3.0),4.0))+"
- "if((d<1),10.0,if((e<1),if((f<1),20.0,30.0),40.0))");
- CompiledFunction compiled_function(function, PassParams::SEPARATE, chain);
- CompiledFunction compiled_function_array(function, PassParams::ARRAY, chain);
- CompiledFunction compiled_function_lazy(function, PassParams::LAZY, chain);
+ auto function = Function::parse("if((a<1),1.0,if((b<1),if((c<1),2.0,3.0),4.0))+"
+ "if((d<1),10.0,if((e<1),if((f<1),20.0,30.0),40.0))");
+ CompiledFunction compiled_function(*function, PassParams::SEPARATE, chain);
+ CompiledFunction compiled_function_array(*function, PassParams::ARRAY, chain);
+ CompiledFunction compiled_function_lazy(*function, PassParams::LAZY, chain);
EXPECT_EQUAL(0u, compiled_function.get_forests().size());
EXPECT_EQUAL(1u, compiled_function_array.get_forests().size());
EXPECT_EQUAL(1u, compiled_function_lazy.get_forests().size());
@@ -226,12 +226,12 @@ TEST("require that trees can be optimized by a forest optimizer when using ARRAY
size_t tree_size = 20;
for (size_t forest_size = 10; forest_size <= 100; forest_size += 10) {
vespalib::string expression = Model().make_forest(forest_size, tree_size);
- Function function = Function::parse(expression);
- CompiledFunction compiled_function(function, PassParams::ARRAY, chain);
- std::vector<double> inputs(function.num_params(), 0.5);
+ auto function = Function::parse(expression);
+ CompiledFunction compiled_function(*function, PassParams::ARRAY, chain);
+ std::vector<double> inputs(function->num_params(), 0.5);
if (forest_size < 25) {
EXPECT_EQUAL(0u, compiled_function.get_forests().size());
- EXPECT_EQUAL(eval_double(function, inputs), compiled_function.get_function()(&inputs[0]));
+ EXPECT_EQUAL(eval_double(*function, inputs), compiled_function.get_function()(&inputs[0]));
} else if (forest_size < 50) {
EXPECT_EQUAL(1u, compiled_function.get_forests().size());
EXPECT_EQUAL(double(forest_size), compiled_function.get_function()(&inputs[0]));
@@ -247,12 +247,12 @@ TEST("require that trees can be optimized by a forest optimizer when using LAZY
size_t tree_size = 20;
for (size_t forest_size = 10; forest_size <= 100; forest_size += 10) {
vespalib::string expression = Model().make_forest(forest_size, tree_size);
- Function function = Function::parse(expression);
- CompiledFunction compiled_function(function, PassParams::LAZY, chain);
- std::vector<double> inputs(function.num_params(), 0.5);
+ auto function = Function::parse(expression);
+ CompiledFunction compiled_function(*function, PassParams::LAZY, chain);
+ std::vector<double> inputs(function->num_params(), 0.5);
if (forest_size < 25) {
EXPECT_EQUAL(0u, compiled_function.get_forests().size());
- EXPECT_EQUAL(eval_double(function, inputs), compiled_function.get_lazy_function()(my_resolve, &inputs[0]));
+ EXPECT_EQUAL(eval_double(*function, inputs), compiled_function.get_lazy_function()(my_resolve, &inputs[0]));
} else if (forest_size < 50) {
EXPECT_EQUAL(1u, compiled_function.get_forests().size());
EXPECT_EQUAL(double(forest_size), compiled_function.get_lazy_function()(my_resolve, &inputs[0]));
@@ -269,9 +269,9 @@ Optimize::Chain less_only_vm_chain({VMForest::less_only_optimize});
Optimize::Chain general_vm_chain({VMForest::general_optimize});
TEST("require that less only VM tree optimizer works") {
- Function function = Function::parse("if((a<1),1.0,if((b<1),if((c<1),2.0,3.0),4.0))+"
- "if((d<1),10.0,if((e<1),if((f<1),20.0,30.0),40.0))");
- CompiledFunction compiled_function(function, PassParams::ARRAY, less_only_vm_chain);
+ auto function = Function::parse("if((a<1),1.0,if((b<1),if((c<1),2.0,3.0),4.0))+"
+ "if((d<1),10.0,if((e<1),if((f<1),20.0,30.0),40.0))");
+ CompiledFunction compiled_function(*function, PassParams::ARRAY, less_only_vm_chain);
EXPECT_EQUAL(1u, compiled_function.get_forests().size());
auto f = compiled_function.get_function();
EXPECT_EQUAL(11.0, f(&std::vector<double>({0.5, 0.0, 0.0, 0.5, 0.0, 0.0})[0]));
@@ -281,8 +281,8 @@ TEST("require that less only VM tree optimizer works") {
}
TEST("require that models with in checks are rejected by less only vm optimizer") {
- Function function = Function::parse(Model().less_percent(100).make_forest(300, 30));
- auto trees = extract_trees(function.root());
+ auto function = Function::parse(Model().less_percent(100).make_forest(300, 30));
+ auto trees = extract_trees(function->root());
ForestStats stats(trees);
EXPECT_TRUE(Optimize::apply_chain(less_only_vm_chain, stats, trees).valid());
stats.total_in_checks = 1;
@@ -290,8 +290,8 @@ TEST("require that models with in checks are rejected by less only vm optimizer"
}
TEST("require that models with inverted checks are rejected by less only vm optimizer") {
- Function function = Function::parse(Model().less_percent(100).make_forest(300, 30));
- auto trees = extract_trees(function.root());
+ auto function = Function::parse(Model().less_percent(100).make_forest(300, 30));
+ auto trees = extract_trees(function->root());
ForestStats stats(trees);
EXPECT_TRUE(Optimize::apply_chain(less_only_vm_chain, stats, trees).valid());
stats.total_inverted_checks = 1;
@@ -299,9 +299,9 @@ TEST("require that models with inverted checks are rejected by less only vm opti
}
TEST("require that general VM tree optimizer works") {
- Function function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))+"
+ auto function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))+"
"if((d in [1]),10.0,if(!(e>=1),if((f<1),20.0,30.0),40.0))");
- CompiledFunction compiled_function(function, PassParams::ARRAY, general_vm_chain);
+ CompiledFunction compiled_function(*function, PassParams::ARRAY, general_vm_chain);
EXPECT_EQUAL(1u, compiled_function.get_forests().size());
auto f = compiled_function.get_function();
EXPECT_EQUAL(11.0, f(&std::vector<double>({0.5, 0.0, 0.0, 1.0, 0.0, 0.0})[0]));
@@ -311,8 +311,8 @@ TEST("require that general VM tree optimizer works") {
}
TEST("require that models with too large sets are rejected by general vm optimizer") {
- Function function = Function::parse(Model().less_percent(80).make_forest(300, 30));
- auto trees = extract_trees(function.root());
+ auto function = Function::parse(Model().less_percent(80).make_forest(300, 30));
+ auto trees = extract_trees(function->root());
ForestStats stats(trees);
EXPECT_TRUE(stats.total_in_checks > 0);
EXPECT_TRUE(Optimize::apply_chain(general_vm_chain, stats, trees).valid());
@@ -321,12 +321,12 @@ TEST("require that models with too large sets are rejected by general vm optimiz
}
TEST("require that FastForest model evaluation works") {
- Function function = Function::parse("if((a<2),1.0,if((b<2),if((c<2),2.0,3.0),4.0))+"
+ auto function = Function::parse("if((a<2),1.0,if((b<2),if((c<2),2.0,3.0),4.0))+"
"if(!(c>=1),10.0,if((a<1),if((b<1),20.0,30.0),40.0))");
- CompiledFunction compiled(function, PassParams::ARRAY, Optimize::none);
+ CompiledFunction compiled(*function, PassParams::ARRAY, Optimize::none);
auto f = compiled.get_function();
EXPECT_TRUE(compiled.get_forests().empty());
- auto forest = FastForest::try_convert(function);
+ auto forest = FastForest::try_convert(*function);
ASSERT_TRUE(forest);
auto ctx = forest->create_context();
std::vector<double> p1({0.5, 0.5, 0.5}); // all true: 1.0 + 10.0
@@ -347,21 +347,21 @@ TEST("require that forests evaluate to approximately the same for all evaluation
for (size_t less_percent: std::vector<size_t>({100, 80})) {
for (size_t invert_percent: std::vector<size_t>({0, 50})) {
vespalib::string expression = Model().less_percent(less_percent).invert_percent(invert_percent).make_forest(num_trees, tree_size);
- Function function = Function::parse(expression);
- auto forest = FastForest::try_convert(function);
+ auto function = Function::parse(expression);
+ auto forest = FastForest::try_convert(*function);
EXPECT_EQUAL(bool(forest), bool(less_percent == 100));
- CompiledFunction none(function, pass_params, Optimize::none);
- CompiledFunction deinline(function, pass_params, DeinlineForest::optimize_chain);
- CompiledFunction vm_forest(function, pass_params, VMForest::optimize_chain);
+ CompiledFunction none(*function, pass_params, Optimize::none);
+ CompiledFunction deinline(*function, pass_params, DeinlineForest::optimize_chain);
+ CompiledFunction vm_forest(*function, pass_params, VMForest::optimize_chain);
EXPECT_EQUAL(0u, none.get_forests().size());
ASSERT_EQUAL(1u, deinline.get_forests().size());
EXPECT_TRUE(dynamic_cast<DeinlineForest*>(deinline.get_forests()[0].get()) != nullptr);
ASSERT_EQUAL(1u, vm_forest.get_forests().size());
EXPECT_TRUE(dynamic_cast<VMForest*>(vm_forest.get_forests()[0].get()) != nullptr);
- std::vector<double> inputs(function.num_params(), 0.5);
- std::vector<double> inputs_nan(function.num_params(), std::numeric_limits<double>::quiet_NaN());
- double expected = eval_double(function, inputs);
- double expected_nan = eval_double(function, inputs_nan);
+ std::vector<double> inputs(function->num_params(), 0.5);
+ std::vector<double> inputs_nan(function->num_params(), std::numeric_limits<double>::quiet_NaN());
+ double expected = eval_double(*function, inputs);
+ double expected_nan = eval_double(*function, inputs_nan);
EXPECT_EQUAL(expected, eval_compiled(none, inputs));
EXPECT_EQUAL(expected, eval_compiled(deinline, inputs));
EXPECT_EQUAL(expected, eval_compiled(vm_forest, inputs));
@@ -387,15 +387,15 @@ TEST("require that fast forest evaluation is correct for all tree size categorie
for (size_t less_percent: std::vector<size_t>({100})) {
for (size_t invert_percent: std::vector<size_t>({50})) {
vespalib::string expression = Model().max_features(num_features).less_percent(less_percent).invert_percent(invert_percent).make_forest(num_trees, tree_size);
- Function function = Function::parse(expression);
- auto forest = FastForest::try_convert(function);
+ auto function = Function::parse(expression);
+ auto forest = FastForest::try_convert(*function);
if ((tree_size <= 64) || is_little_endian()) {
ASSERT_TRUE(forest);
TEST_STATE(forest->impl_name().c_str());
- std::vector<double> inputs(function.num_params(), 0.5);
- std::vector<double> inputs_nan(function.num_params(), std::numeric_limits<double>::quiet_NaN());
- double expected = eval_double(function, inputs);
- double expected_nan = eval_double(function, inputs_nan);
+ std::vector<double> inputs(function->num_params(), 0.5);
+ std::vector<double> inputs_nan(function->num_params(), std::numeric_limits<double>::quiet_NaN());
+ double expected = eval_double(*function, inputs);
+ double expected_nan = eval_double(*function, inputs_nan);
auto ctx = forest->create_context();
EXPECT_EQUAL(expected, eval_ff(*forest, *ctx, inputs));
EXPECT_EQUAL(expected_nan, eval_ff(*forest, *ctx, inputs_nan));
@@ -410,31 +410,31 @@ TEST("require that fast forest evaluation is correct for all tree size categorie
//-----------------------------------------------------------------------------
TEST("require that GDBT expressions can be detected") {
- Function function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))+"
- "if((d in [1]),10.0,if(!(e>=1),20.0,30.0))+"
- "if((d in [1]),10.0,if(!(e>=1),20.0,30.0))");
- EXPECT_TRUE(contains_gbdt(function.root(), 9));
- EXPECT_TRUE(!contains_gbdt(function.root(), 10));
+ auto function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))+"
+ "if((d in [1]),10.0,if(!(e>=1),20.0,30.0))+"
+ "if((d in [1]),10.0,if(!(e>=1),20.0,30.0))");
+ EXPECT_TRUE(contains_gbdt(function->root(), 9));
+ EXPECT_TRUE(!contains_gbdt(function->root(), 10));
}
TEST("require that wrapped GDBT expressions can be detected") {
- Function function = Function::parse("10*(if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))+"
- "if((d in [1]),10.0,if((e<1),20.0,30.0))+"
- "if((d in [1]),10.0,if((e<1),20.0,30.0)))");
- EXPECT_TRUE(contains_gbdt(function.root(), 9));
- EXPECT_TRUE(!contains_gbdt(function.root(), 10));
+ auto function = Function::parse("10*(if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))+"
+ "if((d in [1]),10.0,if((e<1),20.0,30.0))+"
+ "if((d in [1]),10.0,if((e<1),20.0,30.0)))");
+ EXPECT_TRUE(contains_gbdt(function->root(), 9));
+ EXPECT_TRUE(!contains_gbdt(function->root(), 10));
}
TEST("require that lazy parameters are not suggested for GBDT models") {
- Function function = Function::parse(Model().make_forest(10, 8));
- EXPECT_TRUE(!CompiledFunction::should_use_lazy_params(function));
+ auto function = Function::parse(Model().make_forest(10, 8));
+ EXPECT_TRUE(!CompiledFunction::should_use_lazy_params(*function));
}
TEST("require that lazy parameters can be suggested for small GBDT models") {
- Function function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))+"
- "if((d in [1]),10.0,if((e<1),20.0,30.0))+"
- "if((d in [1]),10.0,if((e<1),20.0,30.0))");
- EXPECT_TRUE(CompiledFunction::should_use_lazy_params(function));
+ auto function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))+"
+ "if((d in [1]),10.0,if((e<1),20.0,30.0))+"
+ "if((d in [1]),10.0,if((e<1),20.0,30.0))");
+ EXPECT_TRUE(CompiledFunction::should_use_lazy_params(*function));
}
//-----------------------------------------------------------------------------
diff --git a/eval/src/tests/eval/gbdt/model.cpp b/eval/src/tests/eval/gbdt/model.cpp
index 8f0d87a4020..c2507d2f056 100644
--- a/eval/src/tests/eval/gbdt/model.cpp
+++ b/eval/src/tests/eval/gbdt/model.cpp
@@ -113,7 +113,7 @@ struct ForestParams {
//-----------------------------------------------------------------------------
-Function make_forest(const ForestParams &params, size_t num_trees) {
+auto make_forest(const ForestParams &params, size_t num_trees) {
return Function::parse(Model(params.model_seed)
.less_percent(params.less_percent)
.make_forest(num_trees, params.tree_size));
diff --git a/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp b/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp
index c9108ee74ce..d946d244d17 100644
--- a/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp
+++ b/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp
@@ -28,10 +28,10 @@ struct MyEvalTest : test::EvalSpec::EvalTest {
virtual void next_expression(const std::vector<vespalib::string> &param_names,
const vespalib::string &expression) override
{
- Function function = Function::parse(param_names, expression);
- ASSERT_TRUE(!function.has_error());
+ auto function = Function::parse(param_names, expression);
+ ASSERT_TRUE(!function->has_error());
bool is_supported = true;
- bool has_issues = InterpretedFunction::detect_issues(function);
+ bool has_issues = InterpretedFunction::detect_issues(*function);
if (is_supported == has_issues) {
const char *supported_str = is_supported ? "supported" : "not supported";
const char *issues_str = has_issues ? "has issues" : "does not have issues";
@@ -46,16 +46,16 @@ struct MyEvalTest : test::EvalSpec::EvalTest {
const vespalib::string &expression,
double expected_result) override
{
- Function function = Function::parse(param_names, expression);
- ASSERT_TRUE(!function.has_error());
+ auto function = Function::parse(param_names, expression);
+ ASSERT_TRUE(!function->has_error());
bool is_supported = true;
- bool has_issues = InterpretedFunction::detect_issues(function);
+ bool has_issues = InterpretedFunction::detect_issues(*function);
if (is_supported && !has_issues) {
vespalib::string desc = as_string(param_names, param_values, expression);
SimpleParams params(param_values);
- verify_result(SimpleTensorEngine::ref(), function, false, "[untyped simple] "+desc, params, expected_result);
- verify_result(DefaultTensorEngine::ref(), function, false, "[untyped prod] "+desc, params, expected_result);
- verify_result(DefaultTensorEngine::ref(), function, true, "[typed prod] "+desc, params, expected_result);
+ verify_result(SimpleTensorEngine::ref(), *function, false, "[untyped simple] "+desc, params, expected_result);
+ verify_result(DefaultTensorEngine::ref(), *function, false, "[untyped prod] "+desc, params, expected_result);
+ verify_result(DefaultTensorEngine::ref(), *function, true, "[typed prod] "+desc, params, expected_result);
}
}
@@ -102,15 +102,15 @@ TEST_FF("require that compiled evaluation passes all conformance tests", MyEvalT
TEST("require that invalid function is tagged with error") {
std::vector<vespalib::string> params({"x", "y", "z", "w"});
- Function function = Function::parse(params, "x & y");
- EXPECT_TRUE(function.has_error());
+ auto function = Function::parse(params, "x & y");
+ EXPECT_TRUE(function->has_error());
}
//-----------------------------------------------------------------------------
size_t count_ifs(const vespalib::string &expr, std::initializer_list<double> params_in) {
- Function fun = Function::parse(expr);
- InterpretedFunction ifun(SimpleTensorEngine::ref(), fun, NodeTypes());
+ auto fun = Function::parse(expr);
+ InterpretedFunction ifun(SimpleTensorEngine::ref(), *fun, NodeTypes());
InterpretedFunction::Context ctx(ifun);
SimpleParams params(params_in);
ifun.eval(ctx, params);
@@ -137,8 +137,8 @@ TEST("require that function pointers can be passed as instruction parameters") {
}
TEST("require that basic addition works") {
- Function function = Function::parse("a+10");
- InterpretedFunction interpreted(SimpleTensorEngine::ref(), function, NodeTypes());
+ auto function = Function::parse("a+10");
+ InterpretedFunction interpreted(SimpleTensorEngine::ref(), *function, NodeTypes());
InterpretedFunction::Context ctx(interpreted);
SimpleParams params_20({20});
SimpleParams params_40({40});
@@ -153,20 +153,20 @@ TEST("require that functions with non-compilable lambdas cannot be interpreted")
auto good_join = Function::parse("join(a,b,f(x,y)(x+y))");
auto bad_map = Function::parse("map(a,f(x)(map(x,f(i)(i+1))))");
auto bad_join = Function::parse("join(a,b,f(x,y)(join(x,y,f(i,j)(i+j))))");
- for (const Function *good: {&good_map, &good_join}) {
+ for (const Function *good: {good_map.get(), good_join.get()}) {
if (!EXPECT_TRUE(!good->has_error())) {
fprintf(stderr, "parse error: %s\n", good->get_error().c_str());
}
EXPECT_TRUE(!InterpretedFunction::detect_issues(*good));
}
- for (const Function *bad: {&bad_map, &bad_join}) {
+ for (const Function *bad: {bad_map.get(), bad_join.get()}) {
if (!EXPECT_TRUE(!bad->has_error())) {
fprintf(stderr, "parse error: %s\n", bad->get_error().c_str());
}
EXPECT_TRUE(InterpretedFunction::detect_issues(*bad));
}
std::cerr << "Example function issues:" << std::endl
- << InterpretedFunction::detect_issues(bad_join).list
+ << InterpretedFunction::detect_issues(*bad_join).list
<< std::endl;
}
diff --git a/eval/src/tests/eval/node_types/node_types_test.cpp b/eval/src/tests/eval/node_types/node_types_test.cpp
index 6fbcea8a8a1..dceaf279594 100644
--- a/eval/src/tests/eval/node_types/node_types_test.cpp
+++ b/eval/src/tests/eval/node_types/node_types_test.cpp
@@ -19,27 +19,27 @@ struct TypeSpecExtractor : public vespalib::eval::SymbolExtractor {
};
void verify(const vespalib::string &type_expr, const vespalib::string &type_spec) {
- Function function = Function::parse(type_expr, TypeSpecExtractor());
- if (!EXPECT_TRUE(!function.has_error())) {
- fprintf(stderr, "parse error: %s\n", function.get_error().c_str());
+ auto function = Function::parse(type_expr, TypeSpecExtractor());
+ if (!EXPECT_TRUE(!function->has_error())) {
+ fprintf(stderr, "parse error: %s\n", function->get_error().c_str());
return;
}
std::vector<ValueType> input_types;
- for (size_t i = 0; i < function.num_params(); ++i) {
- input_types.push_back(ValueType::from_spec(function.param_name(i)));
+ for (size_t i = 0; i < function->num_params(); ++i) {
+ input_types.push_back(ValueType::from_spec(function->param_name(i)));
}
- NodeTypes types(function, input_types);
+ NodeTypes types(*function, input_types);
ValueType expected_type = ValueType::from_spec(type_spec);
- ValueType actual_type = types.get_type(function.root());
+ ValueType actual_type = types.get_type(function->root());
EXPECT_EQUAL(expected_type, actual_type);
}
TEST("require that error nodes have error type") {
- Function function = Function::parse("1 2 3 4 5", TypeSpecExtractor());
- EXPECT_TRUE(function.has_error());
- NodeTypes types(function, std::vector<ValueType>());
+ auto function = Function::parse("1 2 3 4 5", TypeSpecExtractor());
+ EXPECT_TRUE(function->has_error());
+ NodeTypes types(*function, std::vector<ValueType>());
ValueType expected_type = ValueType::from_spec("error");
- ValueType actual_type = types.get_type(function.root());
+ ValueType actual_type = types.get_type(function->root());
EXPECT_EQUAL(expected_type, actual_type);
}
@@ -252,21 +252,21 @@ TEST("require that tensor concat resolves correct type") {
}
TEST("require that double only expressions can be detected") {
- Function plain_fun = Function::parse("1+2");
- Function complex_fun = Function::parse("reduce(a,sum)");
- NodeTypes plain_types(plain_fun, {});
- NodeTypes complex_types(complex_fun, {ValueType::tensor_type({{"x"}})});
- EXPECT_TRUE(plain_types.get_type(plain_fun.root()).is_double());
- EXPECT_TRUE(complex_types.get_type(complex_fun.root()).is_double());
+ auto plain_fun = Function::parse("1+2");
+ auto complex_fun = Function::parse("reduce(a,sum)");
+ NodeTypes plain_types(*plain_fun, {});
+ NodeTypes complex_types(*complex_fun, {ValueType::tensor_type({{"x"}})});
+ EXPECT_TRUE(plain_types.get_type(plain_fun->root()).is_double());
+ EXPECT_TRUE(complex_types.get_type(complex_fun->root()).is_double());
EXPECT_TRUE(plain_types.all_types_are_double());
EXPECT_FALSE(complex_types.all_types_are_double());
}
TEST("require that empty type repo works as expected") {
NodeTypes types;
- Function function = Function::parse("1+2");
- EXPECT_FALSE(function.has_error());
- EXPECT_TRUE(types.get_type(function.root()).is_error());
+ auto function = Function::parse("1+2");
+ EXPECT_FALSE(function->has_error());
+ EXPECT_TRUE(types.get_type(function->root()).is_error());
EXPECT_FALSE(types.all_types_are_double());
}
diff --git a/eval/src/tests/eval/param_usage/param_usage_test.cpp b/eval/src/tests/eval/param_usage/param_usage_test.cpp
index c81be24a9ec..4a8278511b3 100644
--- a/eval/src/tests/eval/param_usage/param_usage_test.cpp
+++ b/eval/src/tests/eval/param_usage/param_usage_test.cpp
@@ -30,47 +30,47 @@ std::ostream &operator<<(std::ostream &out, const List &list) {
TEST("require that simple expression has appropriate parameter usage") {
std::vector<vespalib::string> params({"x", "y", "z"});
- Function function = Function::parse(params, "(x+y)*y");
- EXPECT_EQUAL(List(count_param_usage(function)), List({1.0, 2.0, 0.0}));
- EXPECT_EQUAL(List(check_param_usage(function)), List({1.0, 1.0, 0.0}));
+ auto function = Function::parse(params, "(x+y)*y");
+ EXPECT_EQUAL(List(count_param_usage(*function)), List({1.0, 2.0, 0.0}));
+ EXPECT_EQUAL(List(check_param_usage(*function)), List({1.0, 1.0, 0.0}));
}
TEST("require that if children have 50% probability each by default") {
std::vector<vespalib::string> params({"x", "y", "z", "w"});
- Function function = Function::parse(params, "if(w,(x+y)*y,(y+z)*z)");
- EXPECT_EQUAL(List(count_param_usage(function)), List({0.5, 1.5, 1.0, 1.0}));
- EXPECT_EQUAL(List(check_param_usage(function)), List({0.5, 1.0, 0.5, 1.0}));
+ auto function = Function::parse(params, "if(w,(x+y)*y,(y+z)*z)");
+ EXPECT_EQUAL(List(count_param_usage(*function)), List({0.5, 1.5, 1.0, 1.0}));
+ EXPECT_EQUAL(List(check_param_usage(*function)), List({0.5, 1.0, 0.5, 1.0}));
}
TEST("require that if children probability can be adjusted") {
std::vector<vespalib::string> params({"x", "y", "z"});
- Function function = Function::parse(params, "if(z,x*x,y*y,0.8)");
- EXPECT_EQUAL(List(count_param_usage(function)), List({1.6, 0.4, 1.0}));
- EXPECT_EQUAL(List(check_param_usage(function)), List({0.8, 0.2, 1.0}));
+ auto function = Function::parse(params, "if(z,x*x,y*y,0.8)");
+ EXPECT_EQUAL(List(count_param_usage(*function)), List({1.6, 0.4, 1.0}));
+ EXPECT_EQUAL(List(check_param_usage(*function)), List({0.8, 0.2, 1.0}));
}
TEST("require that chained if statements are combined correctly") {
std::vector<vespalib::string> params({"x", "y", "z", "w"});
- Function function = Function::parse(params, "if(z,x,y)+if(w,y,x)");
- EXPECT_EQUAL(List(count_param_usage(function)), List({1.0, 1.0, 1.0, 1.0}));
- EXPECT_EQUAL(List(check_param_usage(function)), List({0.75, 0.75, 1.0, 1.0}));
+ auto function = Function::parse(params, "if(z,x,y)+if(w,y,x)");
+ EXPECT_EQUAL(List(count_param_usage(*function)), List({1.0, 1.0, 1.0, 1.0}));
+ EXPECT_EQUAL(List(check_param_usage(*function)), List({0.75, 0.75, 1.0, 1.0}));
}
TEST("require that multi-level if statements are combined correctly") {
std::vector<vespalib::string> params({"x", "y", "z", "w"});
- Function function = Function::parse(params, "if(z,if(w,y*x,x*x),if(w,y*x,x*x))");
- EXPECT_EQUAL(List(count_param_usage(function)), List({1.5, 0.5, 1.0, 1.0}));
- EXPECT_EQUAL(List(check_param_usage(function)), List({1.0, 0.5, 1.0, 1.0}));
+ auto function = Function::parse(params, "if(z,if(w,y*x,x*x),if(w,y*x,x*x))");
+ EXPECT_EQUAL(List(count_param_usage(*function)), List({1.5, 0.5, 1.0, 1.0}));
+ EXPECT_EQUAL(List(check_param_usage(*function)), List({1.0, 0.5, 1.0, 1.0}));
}
TEST("require that lazy parameters are suggested for functions with parameters that might not be used") {
- Function function = Function::parse("if(z,x,y)+if(w,y,x)");
- EXPECT_TRUE(CompiledFunction::should_use_lazy_params(function));
+ auto function = Function::parse("if(z,x,y)+if(w,y,x)");
+ EXPECT_TRUE(CompiledFunction::should_use_lazy_params(*function));
}
TEST("require that lazy parameters are not suggested for functions where all parameters are always used") {
- Function function = Function::parse("a*b*c");
- EXPECT_TRUE(!CompiledFunction::should_use_lazy_params(function));
+ auto function = Function::parse("a*b*c");
+ EXPECT_TRUE(!CompiledFunction::should_use_lazy_params(*function));
}
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/vespa/eval/eval/basic_nodes.h b/eval/src/vespa/eval/eval/basic_nodes.h
index 06c296fe753..c1192585f7c 100644
--- a/eval/src/vespa/eval/eval/basic_nodes.h
+++ b/eval/src/vespa/eval/eval/basic_nodes.h
@@ -61,7 +61,7 @@ struct Node {
virtual ~Node() {}
};
-typedef std::unique_ptr<Node> Node_UP;
+using Node_UP = std::unique_ptr<Node>;
/**
* Simple typecasting utility. Intended usage:
diff --git a/eval/src/vespa/eval/eval/function.cpp b/eval/src/vespa/eval/eval/function.cpp
index baaaff25b34..108380f52b7 100644
--- a/eval/src/vespa/eval/eval/function.cpp
+++ b/eval/src/vespa/eval/eval/function.cpp
@@ -565,7 +565,7 @@ std::vector<vespalib::string> get_idents(ParseContext &ctx) {
return list;
}
-Function parse_lambda(ParseContext &ctx, size_t num_params) {
+auto parse_lambda(ParseContext &ctx, size_t num_params) {
ctx.skip_spaces();
ctx.eat('f');
auto param_names = get_ident_list(ctx, true);
@@ -581,13 +581,13 @@ Function parse_lambda(ParseContext &ctx, size_t num_params) {
ctx.fail(make_string("expected lambda with %zu parameter(s), was %zu",
num_params, param_names.size()));
}
- return Function(std::move(lambda_root), std::move(param_names));
+ return Function::create(std::move(lambda_root), std::move(param_names));
}
void parse_tensor_map(ParseContext &ctx) {
Node_UP child = get_expression(ctx);
ctx.eat(',');
- Function lambda = parse_lambda(ctx, 1);
+ auto lambda = parse_lambda(ctx, 1);
ctx.push_expression(std::make_unique<nodes::TensorMap>(std::move(child), std::move(lambda)));
}
@@ -596,7 +596,7 @@ void parse_tensor_join(ParseContext &ctx) {
ctx.eat(',');
Node_UP rhs = get_expression(ctx);
ctx.eat(',');
- Function lambda = parse_lambda(ctx, 2);
+ auto lambda = parse_lambda(ctx, 2);
ctx.push_expression(std::make_unique<nodes::TensorJoin>(std::move(lhs), std::move(rhs), std::move(lambda)));
}
@@ -994,15 +994,15 @@ void parse_expression(ParseContext &ctx) {
}
}
-Function parse_function(const Params &params, vespalib::stringref expression,
- const SymbolExtractor *symbol_extractor)
+auto parse_function(const Params &params, vespalib::stringref expression,
+ const SymbolExtractor *symbol_extractor)
{
ParseContext ctx(params, expression.data(), expression.size(), symbol_extractor);
parse_expression(ctx);
if (ctx.failed() && params.implicit()) {
- return Function(ctx.get_result(), std::vector<vespalib::string>());
+ return Function::create(ctx.get_result(), std::vector<vespalib::string>());
}
- return Function(ctx.get_result(), params.extract());
+ return Function::create(ctx.get_result(), params.extract());
}
} // namespace vespalib::<unnamed>
@@ -1023,25 +1023,31 @@ Function::get_error() const
return error ? error->message() : "";
}
-Function
+std::shared_ptr<Function const>
+Function::create(nodes::Node_UP root_in, std::vector<vespalib::string> params_in)
+{
+ return std::make_shared<Function const>(std::move(root_in), std::move(params_in), ctor_tag());
+}
+
+std::shared_ptr<Function const>
Function::parse(vespalib::stringref expression)
{
return parse_function(ImplicitParams(), expression, nullptr);
}
-Function
+std::shared_ptr<Function const>
Function::parse(vespalib::stringref expression, const SymbolExtractor &symbol_extractor)
{
return parse_function(ImplicitParams(), expression, &symbol_extractor);
}
-Function
+std::shared_ptr<Function const>
Function::parse(const std::vector<vespalib::string> &params, vespalib::stringref expression)
{
return parse_function(ExplicitParams(params), expression, nullptr);
}
-Function
+std::shared_ptr<Function const>
Function::parse(const std::vector<vespalib::string> &params, vespalib::stringref expression,
const SymbolExtractor &symbol_extractor)
{
diff --git a/eval/src/vespa/eval/eval/function.h b/eval/src/vespa/eval/eval/function.h
index b3505aad141..b9b9d091961 100644
--- a/eval/src/vespa/eval/eval/function.h
+++ b/eval/src/vespa/eval/eval/function.h
@@ -34,28 +34,33 @@ struct NodeVisitor;
* AST root and the names of all parameters. A function can only be
* evaluated using the appropriate number of parameters.
**/
-class Function
+class Function : public std::enable_shared_from_this<Function>
{
private:
nodes::Node_UP _root;
std::vector<vespalib::string> _params;
+ struct ctor_tag {};
+
public:
- Function() : _root(new nodes::Number(0.0)), _params() {}
- Function(nodes::Node_UP root_in, std::vector<vespalib::string> &&params_in)
+ Function(nodes::Node_UP root_in, std::vector<vespalib::string> params_in, ctor_tag)
: _root(std::move(root_in)), _params(std::move(params_in)) {}
- Function(Function &&rhs) : _root(std::move(rhs._root)), _params(std::move(rhs._params)) {}
+ Function(Function &&rhs) = delete;
+ Function(const Function &rhs) = delete;
+ Function &operator=(Function &&rhs) = delete;
+ Function &operator=(const Function &rhs) = delete;
~Function() { delete_node(std::move(_root)); }
size_t num_params() const { return _params.size(); }
vespalib::stringref param_name(size_t idx) const { return _params[idx]; }
bool has_error() const;
vespalib::string get_error() const;
const nodes::Node &root() const { return *_root; }
- static Function parse(vespalib::stringref expression);
- static Function parse(vespalib::stringref expression, const SymbolExtractor &symbol_extractor);
- static Function parse(const std::vector<vespalib::string> &params, vespalib::stringref expression);
- static Function parse(const std::vector<vespalib::string> &params, vespalib::stringref expression,
- const SymbolExtractor &symbol_extractor);
+ static std::shared_ptr<Function const> create(nodes::Node_UP root_in, std::vector<vespalib::string> params_in);
+ static std::shared_ptr<Function const> parse(vespalib::stringref expression);
+ static std::shared_ptr<Function const> parse(vespalib::stringref expression, const SymbolExtractor &symbol_extractor);
+ static std::shared_ptr<Function const> parse(const std::vector<vespalib::string> &params, vespalib::stringref expression);
+ static std::shared_ptr<Function const> parse(const std::vector<vespalib::string> &params, vespalib::stringref expression,
+ const SymbolExtractor &symbol_extractor);
vespalib::string dump() const {
nodes::DumpContext dump_context(_params);
return _root->dump(dump_context);
diff --git a/eval/src/vespa/eval/eval/llvm/compile_cache.cpp b/eval/src/vespa/eval/eval/llvm/compile_cache.cpp
index f4dfad32121..4f93c496f2a 100644
--- a/eval/src/vespa/eval/eval/llvm/compile_cache.cpp
+++ b/eval/src/vespa/eval/eval/llvm/compile_cache.cpp
@@ -2,13 +2,21 @@
#include "compile_cache.h"
#include <vespa/eval/eval/key_gen.h>
-#include <thread>
namespace vespalib {
namespace eval {
-std::mutex CompileCache::_lock;
-CompileCache::Map CompileCache::_cached;
+std::mutex CompileCache::_lock{};
+CompileCache::Map CompileCache::_cached{};
+Executor *CompileCache::_executor{nullptr};
+
+const CompiledFunction &
+CompileCache::Value::wait_for_result()
+{
+ std::unique_lock<std::mutex> guard(_lock);
+ cond.wait(guard, [this](){ return bool(compiled_function); });
+ return *compiled_function;
+}
void
CompileCache::release(Map::iterator entry)
@@ -22,11 +30,45 @@ CompileCache::release(Map::iterator entry)
CompileCache::Token::UP
CompileCache::compile(const Function &function, PassParams pass_params)
{
+ Token::UP token;
+ CompileTask::UP task;
+ vespalib::string key = gen_key(function, pass_params);
+ {
+ std::lock_guard<std::mutex> guard(_lock);
+ auto pos = _cached.find(key);
+ if (pos != _cached.end()) {
+ ++(pos->second.num_refs);
+ token = std::make_unique<Token>(pos, Token::ctor_tag());
+ } else {
+ auto res = _cached.emplace(std::move(key), Value::ctor_tag());
+ assert(res.second);
+ token = std::make_unique<Token>(res.first, Token::ctor_tag());
+ ++(res.first->second.num_refs);
+ task = std::make_unique<CompileTask>(function, pass_params,
+ std::make_unique<Token>(res.first, Token::ctor_tag()));
+ if (_executor != nullptr) {
+ task = _executor->execute(std::move(task));
+ }
+ }
+ }
+ if (task) {
+ task->run();
+ }
+ return token;
+}
+
+void
+CompileCache::attach_executor(Executor &executor)
+{
+ std::lock_guard<std::mutex> guard(_lock);
+ _executor = &executor;
+}
+
+void
+CompileCache::detach_executor()
+{
std::lock_guard<std::mutex> guard(_lock);
- CompileContext compile_ctx(function, pass_params);
- std::thread thread(do_compile, std::ref(compile_ctx));
- thread.join();
- return std::move(compile_ctx.token);
+ _executor = nullptr;
}
size_t
@@ -47,18 +89,28 @@ CompileCache::count_refs()
return refs;
}
-void
-CompileCache::do_compile(CompileContext &ctx) {
- vespalib::string key = gen_key(ctx.function, ctx.pass_params);
- auto pos = _cached.find(key);
- if (pos != _cached.end()) {
- ++(pos->second.num_refs);
- ctx.token.reset(new Token(pos));
- } else {
- auto res = _cached.emplace(std::move(key), Value(CompiledFunction(ctx.function, ctx.pass_params)));
- assert(res.second);
- ctx.token.reset(new Token(res.first));
+size_t
+CompileCache::count_pending()
+{
+ std::lock_guard<std::mutex> guard(_lock);
+ size_t pending = 0;
+ for (const auto &entry: _cached) {
+ if (entry.second.compiled_function.get() == nullptr) {
+ ++pending;
+ }
}
+ return pending;
+}
+
+void
+CompileCache::CompileTask::run()
+{
+ auto &entry = token->_entry->second;
+ auto result = std::make_unique<CompiledFunction>(*function, pass_params);
+ std::lock_guard<std::mutex> guard(_lock);
+ entry.compiled_function = std::move(result);
+ entry.cf.store(entry.compiled_function.get(), std::memory_order_release);
+ entry.cond.notify_all();
}
} // namespace vespalib::eval
diff --git a/eval/src/vespa/eval/eval/llvm/compile_cache.h b/eval/src/vespa/eval/eval/llvm/compile_cache.h
index 3fa721b27bf..0adca808256 100644
--- a/eval/src/vespa/eval/eval/llvm/compile_cache.h
+++ b/eval/src/vespa/eval/eval/llvm/compile_cache.h
@@ -3,6 +3,8 @@
#pragma once
#include "compiled_function.h"
+#include <vespa/vespalib/util/executor.h>
+#include <condition_variable>
#include <mutex>
namespace vespalib {
@@ -19,15 +21,27 @@ namespace eval {
class CompileCache
{
private:
- typedef vespalib::string Key;
+ using Key = vespalib::string;
struct Value {
size_t num_refs;
- CompiledFunction cf;
- Value(CompiledFunction &&cf_in) : num_refs(1), cf(std::move(cf_in)) {}
+ std::atomic<const CompiledFunction *> cf;
+ std::condition_variable cond;
+ CompiledFunction::UP compiled_function;
+ struct ctor_tag {};
+ Value(ctor_tag) : num_refs(1), cf(nullptr), cond(), compiled_function() {}
+ const CompiledFunction &wait_for_result();
+ const CompiledFunction &get() {
+ const CompiledFunction *ptr = cf.load(std::memory_order_acquire);
+ if (ptr == nullptr) {
+ return wait_for_result();
+ }
+ return *ptr;
+ }
};
- typedef std::map<Key,Value> Map;
+ using Map = std::map<Key,Value>;
static std::mutex _lock;
static Map _cached;
+ static Executor *_executor;
static void release(Map::iterator entry);
@@ -36,33 +50,37 @@ public:
{
private:
friend class CompileCache;
- CompileCache::Map::iterator entry;
- explicit Token(CompileCache::Map::iterator entry_in)
- : entry(entry_in) {}
+ friend class CompileTask;
+ struct ctor_tag {};
+ CompileCache::Map::iterator _entry;
public:
- typedef std::unique_ptr<Token> UP;
- const CompiledFunction &get() const { return entry->second.cf; }
- ~Token() { CompileCache::release(entry); }
+ Token(Token &&) = delete;
+ Token(const Token &) = delete;
+ Token &operator=(Token &&) = delete;
+ Token &operator=(const Token &) = delete;
+ using UP = std::unique_ptr<Token>;
+ explicit Token(CompileCache::Map::iterator entry, ctor_tag) : _entry(entry) {}
+ const CompiledFunction &get() const { return _entry->second.get(); }
+ ~Token() { CompileCache::release(_entry); }
};
+
static Token::UP compile(const Function &function, PassParams pass_params);
+ static void attach_executor(Executor &executor);
+ static void detach_executor();
static size_t num_cached();
static size_t count_refs();
+ static size_t count_pending();
private:
- struct CompileContext {
- const Function &function;
+ struct CompileTask : public Executor::Task {
+ std::shared_ptr<Function const> function;
PassParams pass_params;
Token::UP token;
- CompileContext(const Function &function_in,
- PassParams pass_params_in)
- : function(function_in),
- pass_params(pass_params_in),
- token() {}
+ CompileTask(const Function &function_in, PassParams pass_params_in, Token::UP token_in)
+ : function(function_in.shared_from_this()), pass_params(pass_params_in), token(std::move(token_in)) {}
+ void run() override;
};
-
- static void do_compile(CompileContext &ctx);
};
} // namespace vespalib::eval
} // namespace vespalib
-
diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
index e5b7d45b0e2..72406f09a0d 100644
--- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
+++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
@@ -643,6 +643,7 @@ FunctionBuilder::~FunctionBuilder() { }
struct InitializeNativeTarget {
InitializeNativeTarget() {
+ assert(llvm::llvm_is_multithreaded());
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
llvm::InitializeNativeTargetAsmParser();
@@ -658,8 +659,6 @@ struct InitializeNativeTarget {
}
} initialize_native_target;
-std::recursive_mutex LLVMWrapper::_global_llvm_lock;
-
LLVMWrapper::LLVMWrapper()
: _context(),
_module(),
@@ -668,17 +667,14 @@ LLVMWrapper::LLVMWrapper()
_forests(),
_plugin_state()
{
- std::lock_guard<std::recursive_mutex> guard(_global_llvm_lock);
_context = std::make_unique<llvm::LLVMContext>();
- _module = std::make_unique< llvm::Module>("LLVMWrapper", *_context);
+ _module = std::make_unique<llvm::Module>("LLVMWrapper", *_context);
}
-
size_t
LLVMWrapper::make_function(size_t num_params, PassParams pass_params, const Node &root,
const gbdt::Optimize::Chain &forest_optimizers)
{
- std::lock_guard<std::recursive_mutex> guard(_global_llvm_lock);
size_t function_id = _functions.size();
FunctionBuilder builder(*_context, *_module,
vespalib::make_string("f%zu", function_id),
@@ -692,7 +688,6 @@ LLVMWrapper::make_function(size_t num_params, PassParams pass_params, const Node
size_t
LLVMWrapper::make_forest_fragment(size_t num_params, const std::vector<const Node *> &fragment)
{
- std::lock_guard<std::recursive_mutex> guard(_global_llvm_lock);
size_t function_id = _functions.size();
FunctionBuilder builder(*_context, *_module,
vespalib::make_string("f%zu", function_id),
@@ -706,7 +701,6 @@ LLVMWrapper::make_forest_fragment(size_t num_params, const std::vector<const Nod
void
LLVMWrapper::compile(llvm::raw_ostream * dumpStream)
{
- std::lock_guard<std::recursive_mutex> guard(_global_llvm_lock);
if (dumpStream) {
_module->print(*dumpStream, nullptr);
}
@@ -718,12 +712,10 @@ LLVMWrapper::compile(llvm::raw_ostream * dumpStream)
void *
LLVMWrapper::get_function_address(size_t function_id)
{
- std::lock_guard<std::recursive_mutex> guard(_global_llvm_lock);
return _engine->getPointerToFunction(_functions[function_id]);
}
LLVMWrapper::~LLVMWrapper() {
- std::lock_guard<std::recursive_mutex> guard(_global_llvm_lock);
_plugin_state.clear();
_forests.clear();
_functions.clear();
diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.h b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.h
index db29771ee9e..040c0bdb73f 100644
--- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.h
+++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.h
@@ -49,8 +49,6 @@ private:
std::vector<gbdt::Forest::UP> _forests;
std::vector<PluginState::UP> _plugin_state;
- static std::recursive_mutex _global_llvm_lock;
-
void compile(llvm::raw_ostream * dumpStream);
public:
LLVMWrapper();
diff --git a/eval/src/vespa/eval/eval/make_tensor_function.cpp b/eval/src/vespa/eval/eval/make_tensor_function.cpp
index a7215eabcb9..803047d27c4 100644
--- a/eval/src/vespa/eval/eval/make_tensor_function.cpp
+++ b/eval/src/vespa/eval/eval/make_tensor_function.cpp
@@ -167,8 +167,8 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser {
for (size_t i = 0; i < node.num_entries(); ++i) {
my_in->add_entry(std::make_unique<Number>(node.get_entry(i).get_const_value()));
}
- Function my_fun(std::move(my_in), {"x"});
- const auto &token = stash.create<CompileCache::Token::UP>(CompileCache::compile(my_fun, PassParams::SEPARATE));
+ auto my_fun = Function::create(std::move(my_in), {"x"});
+ const auto &token = stash.create<CompileCache::Token::UP>(CompileCache::compile(*my_fun, PassParams::SEPARATE));
make_map(node, token.get()->get().get_function<1>());
}
void visit(const Neg &node) override {
diff --git a/eval/src/vespa/eval/eval/tensor_nodes.h b/eval/src/vespa/eval/eval/tensor_nodes.h
index a307faf9b36..4213809f9e3 100644
--- a/eval/src/vespa/eval/eval/tensor_nodes.h
+++ b/eval/src/vespa/eval/eval/tensor_nodes.h
@@ -17,18 +17,18 @@ namespace nodes {
class TensorMap : public Node {
private:
- Node_UP _child;
- Function _lambda;
+ Node_UP _child;
+ std::shared_ptr<Function const> _lambda;
public:
- TensorMap(Node_UP child, Function lambda)
+ TensorMap(Node_UP child, std::shared_ptr<Function const> lambda)
: _child(std::move(child)), _lambda(std::move(lambda)) {}
- const Function &lambda() const { return _lambda; }
+ const Function &lambda() const { return *_lambda; }
vespalib::string dump(DumpContext &ctx) const override {
vespalib::string str;
str += "map(";
str += _child->dump(ctx);
str += ",";
- str += _lambda.dump_as_lambda();
+ str += _lambda->dump_as_lambda();
str += ")";
return str;
}
@@ -46,13 +46,13 @@ public:
class TensorJoin : public Node {
private:
- Node_UP _lhs;
- Node_UP _rhs;
- Function _lambda;
+ Node_UP _lhs;
+ Node_UP _rhs;
+ std::shared_ptr<Function const> _lambda;
public:
- TensorJoin(Node_UP lhs, Node_UP rhs, Function lambda)
+ TensorJoin(Node_UP lhs, Node_UP rhs, std::shared_ptr<Function const> lambda)
: _lhs(std::move(lhs)), _rhs(std::move(rhs)), _lambda(std::move(lambda)) {}
- const Function &lambda() const { return _lambda; }
+ const Function &lambda() const { return *_lambda; }
vespalib::string dump(DumpContext &ctx) const override {
vespalib::string str;
str += "join(";
@@ -60,7 +60,7 @@ public:
str += ",";
str += _rhs->dump(ctx);
str += ",";
- str += _lambda.dump_as_lambda();
+ str += _lambda->dump_as_lambda();
str += ")";
return str;
}
diff --git a/eval/src/vespa/eval/eval/test/eval_fixture.cpp b/eval/src/vespa/eval/eval/test/eval_fixture.cpp
index 30e55d0418b..e578b28da18 100644
--- a/eval/src/vespa/eval/eval/test/eval_fixture.cpp
+++ b/eval/src/vespa/eval/eval/test/eval_fixture.cpp
@@ -91,14 +91,14 @@ EvalFixture::EvalFixture(const TensorEngine &engine,
: _engine(engine),
_stash(),
_function(Function::parse(expr)),
- _node_types(get_types(_function, param_repo)),
- _mutable_set(get_mutable(_function, param_repo)),
- _plain_tensor_function(make_tensor_function(_engine, _function.root(), _node_types, _stash)),
+ _node_types(get_types(*_function, param_repo)),
+ _mutable_set(get_mutable(*_function, param_repo)),
+ _plain_tensor_function(make_tensor_function(_engine, _function->root(), _node_types, _stash)),
_patched_tensor_function(maybe_patch(allow_mutable, _plain_tensor_function, _mutable_set, _stash)),
_tensor_function(optimized ? _engine.optimize(_patched_tensor_function, _stash) : _patched_tensor_function),
_ifun(_engine, _tensor_function),
_ictx(_ifun),
- _param_values(make_params(_engine, _function, param_repo)),
+ _param_values(make_params(_engine, *_function, param_repo)),
_params(get_refs(_param_values)),
_result(_engine.to_spec(_ifun.eval(_ictx, _params)))
{
diff --git a/eval/src/vespa/eval/eval/test/eval_fixture.h b/eval/src/vespa/eval/eval/test/eval_fixture.h
index 9c793c01861..48f6a7e5d2e 100644
--- a/eval/src/vespa/eval/eval/test/eval_fixture.h
+++ b/eval/src/vespa/eval/eval/test/eval_fixture.h
@@ -41,19 +41,19 @@ public:
};
private:
- const TensorEngine &_engine;
- Stash _stash;
- Function _function;
- NodeTypes _node_types;
- std::set<size_t> _mutable_set;
- const TensorFunction &_plain_tensor_function;
- const TensorFunction &_patched_tensor_function;
- const TensorFunction &_tensor_function;
- InterpretedFunction _ifun;
- InterpretedFunction::Context _ictx;
- std::vector<Value::UP> _param_values;
- SimpleObjectParams _params;
- TensorSpec _result;
+ const TensorEngine &_engine;
+ Stash _stash;
+ std::shared_ptr<Function const> _function;
+ NodeTypes _node_types;
+ std::set<size_t> _mutable_set;
+ const TensorFunction &_plain_tensor_function;
+ const TensorFunction &_patched_tensor_function;
+ const TensorFunction &_tensor_function;
+ InterpretedFunction _ifun;
+ InterpretedFunction::Context _ictx;
+ std::vector<Value::UP> _param_values;
+ SimpleObjectParams _params;
+ TensorSpec _result;
template <typename T>
void find_all(const TensorFunction &node, std::vector<const T *> &list) {
diff --git a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp
index 3af7ea0a1f3..9242b19310d 100644
--- a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp
+++ b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp
@@ -119,12 +119,12 @@ struct Expr_V : Eval {
const vespalib::string &expr;
Expr_V(const vespalib::string &expr_in) : expr(expr_in) {}
Result eval(const TensorEngine &engine) const override {
- Function fun = Function::parse(expr);
- NodeTypes types(fun, {});
- InterpretedFunction ifun(engine, fun, types);
+ auto fun = Function::parse(expr);
+ NodeTypes types(*fun, {});
+ InterpretedFunction ifun(engine, *fun, types);
InterpretedFunction::Context ctx(ifun);
SimpleObjectParams params({});
- return Result(engine, check_type(ifun.eval(ctx, params), types.get_type(fun.root())));
+ return Result(engine, check_type(ifun.eval(ctx, params), types.get_type(fun->root())));
}
};
@@ -133,14 +133,14 @@ struct Expr_T : Eval {
const vespalib::string &expr;
Expr_T(const vespalib::string &expr_in) : expr(expr_in) {}
Result eval(const TensorEngine &engine, const TensorSpec &a) const override {
- Function fun = Function::parse(expr);
+ auto fun = Function::parse(expr);
auto a_type = ValueType::from_spec(a.type());
- NodeTypes types(fun, {a_type});
- InterpretedFunction ifun(engine, fun, types);
+ NodeTypes types(*fun, {a_type});
+ InterpretedFunction ifun(engine, *fun, types);
InterpretedFunction::Context ctx(ifun);
Value::UP va = engine.from_spec(a);
SimpleObjectParams params({*va});
- return Result(engine, check_type(ifun.eval(ctx, params), types.get_type(fun.root())));
+ return Result(engine, check_type(ifun.eval(ctx, params), types.get_type(fun->root())));
}
};
@@ -149,16 +149,16 @@ struct Expr_TT : Eval {
vespalib::string expr;
Expr_TT(const vespalib::string &expr_in) : expr(expr_in) {}
Result eval(const TensorEngine &engine, const TensorSpec &a, const TensorSpec &b) const override {
- Function fun = Function::parse(expr);
+ auto fun = Function::parse(expr);
auto a_type = ValueType::from_spec(a.type());
auto b_type = ValueType::from_spec(b.type());
- NodeTypes types(fun, {a_type, b_type});
- InterpretedFunction ifun(engine, fun, types);
+ NodeTypes types(*fun, {a_type, b_type});
+ InterpretedFunction ifun(engine, *fun, types);
InterpretedFunction::Context ctx(ifun);
Value::UP va = engine.from_spec(a);
Value::UP vb = engine.from_spec(b);
SimpleObjectParams params({*va,*vb});
- return Result(engine, check_type(ifun.eval(ctx, params), types.get_type(fun.root())));
+ return Result(engine, check_type(ifun.eval(ctx, params), types.get_type(fun->root())));
}
};