diff options
author | Håvard Pettersen <havardpe@oath.com> | 2018-01-18 12:58:57 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2018-01-18 12:58:57 +0000 |
commit | 15b416446c7a6e674a14b5de0e0243fdab83f340 (patch) | |
tree | 6034dbad7cc484d9f968e97c8ee97e4dd8919894 /eval | |
parent | 346562c8061366326144e5dac22992dcdd5a59ce (diff) |
run all tests from test_spec as tensor functions
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp | 48 |
1 files changed, 32 insertions, 16 deletions
diff --git a/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp b/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp index 91eb741f334..a4222df6e00 100644 --- a/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp +++ b/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp @@ -52,13 +52,27 @@ struct MyEvalTest : test::EvalSpec::EvalTest { if (is_supported && !has_issues) { vespalib::string desc = as_string(param_names, param_values, expression); SimpleParams params(param_values); - verify_result(SimpleTensorEngine::ref(), function, false, "[untyped simple] "+desc, params, expected_result); - verify_result(DefaultTensorEngine::ref(), function, false, "[untyped prod] "+desc, params, expected_result); - verify_result(DefaultTensorEngine::ref(), function, true, "[typed prod] "+desc, params, expected_result); + verify_result(SimpleTensorEngine::ref(), function, false, "[untyped simple] "+desc, params, expected_result); + verify_result(DefaultTensorEngine::ref(), function, false, "[untyped prod] "+desc, params, expected_result); + verify_result(DefaultTensorEngine::ref(), function, true, "[typed prod] "+desc, params, expected_result); + verify_tensor_function(DefaultTensorEngine::ref(), function, "[tensor function]"+desc, params, expected_result); } } - void verify_result(const TensorEngine& engine, + void report_result(bool is_double, double result, double expect, const vespalib::string &desc) + { + if (is_double && is_same(expect, result)) { + print_pass && fprintf(stderr, "verifying: %s -> %g ... PASS\n", + desc.c_str(), expect); + ++pass_cnt; + } else { + print_fail && fprintf(stderr, "verifying: %s -> %g ... FAIL: got %g\n", + desc.c_str(), expect, result); + ++fail_cnt; + } + } + + void verify_result(const TensorEngine &engine, const Function &function, bool typed, const vespalib::string &description, @@ -72,18 +86,20 @@ struct MyEvalTest : test::EvalSpec::EvalTest { ASSERT_EQUAL(ifun.num_params(), params.params.size()); InterpretedFunction::Context ictx(ifun); const Value &result_value = ifun.eval(ictx, params); - double result = result_value.as_double(); - if (result_value.is_double() && is_same(expected_result, result)) { - print_pass && fprintf(stderr, "verifying: %s -> %g ... PASS\n", - description.c_str(), - expected_result); - ++pass_cnt; - } else { - print_fail && fprintf(stderr, "verifying: %s -> %g ... FAIL: got %g\n", - description.c_str(), - expected_result, result); - ++fail_cnt; - } + report_result(result_value.is_double(), result_value.as_double(), expected_result, description); + } + + void verify_tensor_function(const TensorEngine &engine, + const Function &function, + const vespalib::string &description, + const SimpleParams ¶ms, + double expected_result) + { + Stash stash; + NodeTypes node_types = NodeTypes(function, std::vector<ValueType>(params.params.size(), ValueType::double_type())); + const auto &tfun = make_tensor_function(engine, function.root(), node_types, stash); + const Value &result_value = tfun.eval(params, stash); + report_result(result_value.is_double(), result_value.as_double(), expected_result, description); } }; |