summaryrefslogtreecommitdiffstats
path: root/eval/src/tests
diff options
context:
space:
mode:
authorHaavard <havardpe@yahoo-inc.com>2017-02-09 14:44:10 +0000
committerHaavard <havardpe@yahoo-inc.com>2017-02-10 16:07:47 +0000
commita9e64013b4db64cb6ec86be6dcc0076282ab8858 (patch)
treeca16fa7cb4d1efe5a4ccbf8380e235bf15b8baa6 /eval/src/tests
parent461694eed494a7dc0f365725439beee2089eaec5 (diff)
wire in immediate evaluation of new syntax
Diffstat (limited to 'eval/src/tests')
-rw-r--r--eval/src/tests/eval/function_speed/function_speed_test.cpp4
-rw-r--r--eval/src/tests/eval/gbdt/gbdt_test.cpp2
-rw-r--r--eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp69
-rw-r--r--eval/src/tests/tensor/tensor_performance/tensor_performance_test.cpp4
4 files changed, 34 insertions, 45 deletions
diff --git a/eval/src/tests/eval/function_speed/function_speed_test.cpp b/eval/src/tests/eval/function_speed/function_speed_test.cpp
index 41463f0ef5b..bdb93daec19 100644
--- a/eval/src/tests/eval/function_speed/function_speed_test.cpp
+++ b/eval/src/tests/eval/function_speed/function_speed_test.cpp
@@ -21,7 +21,7 @@ double gcc_function(double p, double o, double q, double f, double w) {
return (0.35*p + 0.15*o + 0.30*q + 0.20*f) * w;
}
-InterpretedFunction::Context icontext;
+InterpretedFunction::Context icontext(interpreted_function);
double interpret_function(double p, double o, double q, double f, double w) {
icontext.clear_params();
@@ -52,7 +52,7 @@ double big_gcc_function(double p, double o, double q, double f, double w) {
(0.35*p + 0.15*o + 0.30*q + 0.20*f) * w;
}
-InterpretedFunction::Context big_icontext;
+InterpretedFunction::Context big_icontext(big_interpreted_function);
double big_interpret_function(double p, double o, double q, double f, double w) {
big_icontext.clear_params();
diff --git a/eval/src/tests/eval/gbdt/gbdt_test.cpp b/eval/src/tests/eval/gbdt/gbdt_test.cpp
index 58e4fca2d12..12e79941b44 100644
--- a/eval/src/tests/eval/gbdt/gbdt_test.cpp
+++ b/eval/src/tests/eval/gbdt/gbdt_test.cpp
@@ -17,7 +17,7 @@ using namespace vespalib::eval::gbdt;
double eval_double(const Function &function, const std::vector<double> &params) {
InterpretedFunction ifun(SimpleTensorEngine::ref(), function, NodeTypes());
- InterpretedFunction::Context ctx;
+ InterpretedFunction::Context ctx(ifun);
for (double param: params) {
ctx.add_param(param);
}
diff --git a/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp b/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp
index d39427ac232..4a0051303bb 100644
--- a/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp
+++ b/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp
@@ -15,29 +15,6 @@ using vespalib::Stash;
//-----------------------------------------------------------------------------
-std::vector<vespalib::string> unsupported = {
- "map(",
- "join(",
- "reduce(",
- "rename(",
- "tensor(",
- "concat("
-};
-
-bool is_unsupported(const vespalib::string &expression) {
- if (expression == "reduce(a,sum)") {
- return false;
- }
- for (const auto &prefix: unsupported) {
- if (starts_with(expression, prefix)) {
- return true;
- }
- }
- return false;
-}
-
-//-----------------------------------------------------------------------------
-
struct MyEvalTest : test::EvalSpec::EvalTest {
size_t pass_cnt = 0;
size_t fail_cnt = 0;
@@ -48,7 +25,7 @@ struct MyEvalTest : test::EvalSpec::EvalTest {
{
Function function = Function::parse(param_names, expression);
ASSERT_TRUE(!function.has_error());
- bool is_supported = !is_unsupported(expression);
+ bool is_supported = true;
bool has_issues = InterpretedFunction::detect_issues(function);
if (is_supported == has_issues) {
const char *supported_str = is_supported ? "supported" : "not supported";
@@ -65,12 +42,12 @@ struct MyEvalTest : test::EvalSpec::EvalTest {
{
Function function = Function::parse(param_names, expression);
ASSERT_TRUE(!function.has_error());
- bool is_supported = !is_unsupported(expression);
+ bool is_supported = true;
bool has_issues = InterpretedFunction::detect_issues(function);
if (is_supported && !has_issues) {
InterpretedFunction ifun(SimpleTensorEngine::ref(), function, NodeTypes());
ASSERT_EQUAL(ifun.num_params(), param_values.size());
- InterpretedFunction::Context ictx;
+ InterpretedFunction::Context ictx(ifun);
for (double param: param_values) {
ictx.add_param(param);
}
@@ -106,7 +83,7 @@ TEST("require that invalid function evaluates to a error") {
Function function = Function::parse(params, "x & y");
EXPECT_TRUE(function.has_error());
InterpretedFunction ifun(SimpleTensorEngine::ref(), function, NodeTypes());
- InterpretedFunction::Context ctx;
+ InterpretedFunction::Context ctx(ifun);
ctx.add_param(1);
ctx.add_param(2);
ctx.add_param(3);
@@ -121,7 +98,7 @@ TEST("require that invalid function evaluates to a error") {
size_t count_ifs(const vespalib::string &expr, std::initializer_list<double> params_in) {
Function fun = Function::parse(expr);
InterpretedFunction ifun(SimpleTensorEngine::ref(), fun, NodeTypes());
- InterpretedFunction::Context ctx;
+ InterpretedFunction::Context ctx(ifun);
for (double param: params_in) {
ctx.add_param(param);
}
@@ -147,7 +124,7 @@ TEST("require that interpreted function instructions have expected size") {
TEST("require that basic addition works") {
Function function = Function::parse("a+10");
InterpretedFunction interpreted(SimpleTensorEngine::ref(), function, NodeTypes());
- InterpretedFunction::Context ctx;
+ InterpretedFunction::Context ctx(interpreted);
ctx.add_param(20);
EXPECT_EQUAL(interpreted.eval(ctx).as_double(), 30.0);
ctx.clear_params();
@@ -165,7 +142,7 @@ TEST("require that dot product like expression is not optimized for unknown type
double expect = (2.0 * 3.0);
InterpretedFunction interpreted(engine, function, NodeTypes());
EXPECT_EQUAL(4u, interpreted.program_size());
- InterpretedFunction::Context ctx;
+ InterpretedFunction::Context ctx(interpreted);
ctx.add_param(a);
ctx.add_param(b);
const Value &result = interpreted.eval(ctx);
@@ -188,7 +165,7 @@ TEST("require that dot product works with tensor function") {
NodeTypes types(function, {ValueType::from_spec(a.type()), ValueType::from_spec(a.type())});
InterpretedFunction interpreted(engine, function, types);
EXPECT_EQUAL(1u, interpreted.program_size());
- InterpretedFunction::Context ctx;
+ InterpretedFunction::Context ctx(interpreted);
TensorValue va(engine.create(a));
TensorValue vb(engine.create(b));
ctx.add_param(va);
@@ -219,7 +196,7 @@ TEST("require that matrix multiplication works with tensor function") {
NodeTypes types(function, {ValueType::from_spec(a.type()), ValueType::from_spec(a.type())});
InterpretedFunction interpreted(engine, function, types);
EXPECT_EQUAL(1u, interpreted.program_size());
- InterpretedFunction::Context ctx;
+ InterpretedFunction::Context ctx(interpreted);
TensorValue va(engine.create(a));
TensorValue vb(engine.create(b));
ctx.add_param(va);
@@ -231,15 +208,27 @@ TEST("require that matrix multiplication works with tensor function") {
//-----------------------------------------------------------------------------
-TEST("require function issues can be detected") {
- auto simple = Function::parse("a+b");
- auto complex = Function::parse("join(a,b,f(a,b)(a+b))");
- EXPECT_FALSE(simple.has_error());
- EXPECT_FALSE(complex.has_error());
- EXPECT_FALSE(InterpretedFunction::detect_issues(simple));
- EXPECT_TRUE(InterpretedFunction::detect_issues(complex));
+TEST("require that functions with non-compilable lambdas cannot be interpreted") {
+ auto good_map = Function::parse("map(a,f(x)(x+1))");
+ auto good_join = Function::parse("join(a,b,f(x,y)(x+y))");
+ auto good_tensor = Function::parse("tensor(a[10],b[10])(a+b)");
+ auto bad_map = Function::parse("map(a,f(x)(map(x,f(i)(i+1))))");
+ auto bad_join = Function::parse("join(a,b,f(x,y)(join(x,y,f(i,j)(i+j))))");
+ auto bad_tensor = Function::parse("tensor(a[10],b[10])(join(a,b,f(i,j)(i+j)))");
+ for (const Function *good: {&good_map, &good_join, &good_tensor}) {
+ if (!EXPECT_TRUE(!good->has_error())) {
+ fprintf(stderr, "parse error: %s\n", good->get_error().c_str());
+ }
+ EXPECT_TRUE(!InterpretedFunction::detect_issues(*good));
+ }
+ for (const Function *bad: {&bad_map, &bad_join, &bad_tensor}) {
+ if (!EXPECT_TRUE(!bad->has_error())) {
+ fprintf(stderr, "parse error: %s\n", bad->get_error().c_str());
+ }
+ EXPECT_TRUE(InterpretedFunction::detect_issues(*bad));
+ }
std::cerr << "Example function issues:" << std::endl
- << InterpretedFunction::detect_issues(complex).list
+ << InterpretedFunction::detect_issues(bad_tensor).list
<< std::endl;
}
diff --git a/eval/src/tests/tensor/tensor_performance/tensor_performance_test.cpp b/eval/src/tests/tensor/tensor_performance/tensor_performance_test.cpp
index 105eb955413..64bec6d1186 100644
--- a/eval/src/tests/tensor/tensor_performance/tensor_performance_test.cpp
+++ b/eval/src/tests/tensor/tensor_performance/tensor_performance_test.cpp
@@ -69,7 +69,7 @@ double calculate_expression(const vespalib::string &expression, const Params &pa
const Function function = Function::parse(expression);
const NodeTypes types(function, extract_param_types(function, params));
const InterpretedFunction interpreted(tensor::DefaultTensorEngine::ref(), function, types);
- InterpretedFunction::Context context;
+ InterpretedFunction::Context context(interpreted);
inject_params(function, params, context);
const Value &result = interpreted.eval(context);
EXPECT_TRUE(result.is_double());
@@ -83,7 +83,7 @@ double benchmark_expression_us(const vespalib::string &expression, const Params
const Function function = Function::parse(expression);
const NodeTypes types(function, extract_param_types(function, params));
const InterpretedFunction interpreted(tensor::DefaultTensorEngine::ref(), function, types);
- InterpretedFunction::Context context;
+ InterpretedFunction::Context context(interpreted);
inject_params(function, params, context);
auto ranking = [&](){ interpreted.eval(context); };
auto baseline = [&](){ dummy_ranking(context); };