summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2018-01-18 12:58:57 +0000
committerHåvard Pettersen <havardpe@oath.com>2018-01-18 12:58:57 +0000
commit15b416446c7a6e674a14b5de0e0243fdab83f340 (patch)
tree6034dbad7cc484d9f968e97c8ee97e4dd8919894 /eval
parent346562c8061366326144e5dac22992dcdd5a59ce (diff)
run all tests from test_spec as tensor functions
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp48
1 files changed, 32 insertions, 16 deletions
diff --git a/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp b/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp
index 91eb741f334..a4222df6e00 100644
--- a/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp
+++ b/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp
@@ -52,13 +52,27 @@ struct MyEvalTest : test::EvalSpec::EvalTest {
if (is_supported && !has_issues) {
vespalib::string desc = as_string(param_names, param_values, expression);
SimpleParams params(param_values);
- verify_result(SimpleTensorEngine::ref(), function, false, "[untyped simple] "+desc, params, expected_result);
- verify_result(DefaultTensorEngine::ref(), function, false, "[untyped prod] "+desc, params, expected_result);
- verify_result(DefaultTensorEngine::ref(), function, true, "[typed prod] "+desc, params, expected_result);
+ verify_result(SimpleTensorEngine::ref(), function, false, "[untyped simple] "+desc, params, expected_result);
+ verify_result(DefaultTensorEngine::ref(), function, false, "[untyped prod] "+desc, params, expected_result);
+ verify_result(DefaultTensorEngine::ref(), function, true, "[typed prod] "+desc, params, expected_result);
+ verify_tensor_function(DefaultTensorEngine::ref(), function, "[tensor function]"+desc, params, expected_result);
}
}
- void verify_result(const TensorEngine& engine,
+ void report_result(bool is_double, double result, double expect, const vespalib::string &desc)
+ {
+ if (is_double && is_same(expect, result)) {
+ print_pass && fprintf(stderr, "verifying: %s -> %g ... PASS\n",
+ desc.c_str(), expect);
+ ++pass_cnt;
+ } else {
+ print_fail && fprintf(stderr, "verifying: %s -> %g ... FAIL: got %g\n",
+ desc.c_str(), expect, result);
+ ++fail_cnt;
+ }
+ }
+
+ void verify_result(const TensorEngine &engine,
const Function &function,
bool typed,
const vespalib::string &description,
@@ -72,18 +86,20 @@ struct MyEvalTest : test::EvalSpec::EvalTest {
ASSERT_EQUAL(ifun.num_params(), params.params.size());
InterpretedFunction::Context ictx(ifun);
const Value &result_value = ifun.eval(ictx, params);
- double result = result_value.as_double();
- if (result_value.is_double() && is_same(expected_result, result)) {
- print_pass && fprintf(stderr, "verifying: %s -> %g ... PASS\n",
- description.c_str(),
- expected_result);
- ++pass_cnt;
- } else {
- print_fail && fprintf(stderr, "verifying: %s -> %g ... FAIL: got %g\n",
- description.c_str(),
- expected_result, result);
- ++fail_cnt;
- }
+ report_result(result_value.is_double(), result_value.as_double(), expected_result, description);
+ }
+
+ void verify_tensor_function(const TensorEngine &engine,
+ const Function &function,
+ const vespalib::string &description,
+ const SimpleParams &params,
+ double expected_result)
+ {
+ Stash stash;
+ NodeTypes node_types = NodeTypes(function, std::vector<ValueType>(params.params.size(), ValueType::double_type()));
+ const auto &tfun = make_tensor_function(engine, function.root(), node_types, stash);
+ const Value &result_value = tfun.eval(params, stash);
+ report_result(result_value.is_double(), result_value.as_double(), expected_result, description);
}
};