diff options
author | Håvard Pettersen <havardpe@oath.com> | 2017-10-26 09:55:28 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2017-10-26 09:55:28 +0000 |
commit | 5173775ac89edc765835c4fc8ac6a77383fc4c6e (patch) | |
tree | 4c435bc22bfb77c89079a0cabd609ae8789ccd1f /eval | |
parent | 5ae6c1e24bbe54daa705d646774d4d90a48dd971 (diff) |
test both reference and production tensor engines
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp | 47 |
1 files changed, 31 insertions, 16 deletions
diff --git a/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp b/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp index 4497c5f5e70..9aeb28a17b1 100644 --- a/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp +++ b/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp @@ -5,6 +5,7 @@ #include <vespa/eval/eval/interpreted_function.h> #include <vespa/eval/eval/test/eval_spec.h> #include <vespa/eval/eval/basic_nodes.h> +#include <vespa/eval/tensor/default_tensor_engine.h> #include <vespa/vespalib/util/stringfmt.h> #include <vespa/vespalib/util/stash.h> #include <vespa/vespalib/test/insertion_operators.h> @@ -12,6 +13,7 @@ using namespace vespalib::eval; using vespalib::Stash; +using vespalib::tensor::DefaultTensorEngine; //----------------------------------------------------------------------------- @@ -20,6 +22,7 @@ struct MyEvalTest : test::EvalSpec::EvalTest { size_t fail_cnt = 0; bool print_pass = false; bool print_fail = false; + virtual void next_expression(const std::vector<vespalib::string> ¶m_names, const vespalib::string &expression) override { @@ -35,6 +38,7 @@ struct MyEvalTest : test::EvalSpec::EvalTest { ++fail_cnt; } } + virtual void handle_case(const std::vector<vespalib::string> ¶m_names, const std::vector<double> ¶m_values, const vespalib::string &expression, @@ -45,23 +49,34 @@ struct MyEvalTest : test::EvalSpec::EvalTest { bool is_supported = true; bool has_issues = InterpretedFunction::detect_issues(function); if (is_supported && !has_issues) { - InterpretedFunction ifun(SimpleTensorEngine::ref(), function, NodeTypes()); - ASSERT_EQUAL(ifun.num_params(), param_values.size()); - InterpretedFunction::Context ictx(ifun); + vespalib::string desc = as_string(param_names, param_values, expression); InterpretedFunction::SimpleParams params(param_values); - const Value &result_value = ifun.eval(ictx, params); - double result = result_value.as_double(); - if (result_value.is_double() && is_same(expected_result, result)) { - print_pass && fprintf(stderr, "verifying: %s -> %g ... PASS\n", - as_string(param_names, param_values, expression).c_str(), - expected_result); - ++pass_cnt; - } else { - print_fail && fprintf(stderr, "verifying: %s -> %g ... FAIL: got %g\n", - as_string(param_names, param_values, expression).c_str(), - expected_result, result); - ++fail_cnt; - } + verify_result(SimpleTensorEngine::ref(), function, "[simple] "+desc, params, expected_result); + verify_result(DefaultTensorEngine::ref(), function, " [prod] "+desc, params, expected_result); + } + } + + void verify_result(const TensorEngine& engine, + const Function &function, + const vespalib::string &description, + const InterpretedFunction::SimpleParams ¶ms, + double expected_result) + { + InterpretedFunction ifun(engine, function, NodeTypes()); + ASSERT_EQUAL(ifun.num_params(), params.params.size()); + InterpretedFunction::Context ictx(ifun); + const Value &result_value = ifun.eval(ictx, params); + double result = result_value.as_double(); + if (result_value.is_double() && is_same(expected_result, result)) { + print_pass && fprintf(stderr, "verifying: %s -> %g ... PASS\n", + description.c_str(), + expected_result); + ++pass_cnt; + } else { + print_fail && fprintf(stderr, "verifying: %s -> %g ... FAIL: got %g\n", + description.c_str(), + expected_result, result); + ++fail_cnt; } } }; |