summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2017-10-26 09:55:28 +0000
committerHåvard Pettersen <havardpe@oath.com>2017-10-26 09:55:28 +0000
commit5173775ac89edc765835c4fc8ac6a77383fc4c6e (patch)
tree4c435bc22bfb77c89079a0cabd609ae8789ccd1f /eval
parent5ae6c1e24bbe54daa705d646774d4d90a48dd971 (diff)
test both reference and production tensor engines
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp47
1 files changed, 31 insertions, 16 deletions
diff --git a/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp b/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp
index 4497c5f5e70..9aeb28a17b1 100644
--- a/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp
+++ b/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp
@@ -5,6 +5,7 @@
#include <vespa/eval/eval/interpreted_function.h>
#include <vespa/eval/eval/test/eval_spec.h>
#include <vespa/eval/eval/basic_nodes.h>
+#include <vespa/eval/tensor/default_tensor_engine.h>
#include <vespa/vespalib/util/stringfmt.h>
#include <vespa/vespalib/util/stash.h>
#include <vespa/vespalib/test/insertion_operators.h>
@@ -12,6 +13,7 @@
using namespace vespalib::eval;
using vespalib::Stash;
+using vespalib::tensor::DefaultTensorEngine;
//-----------------------------------------------------------------------------
@@ -20,6 +22,7 @@ struct MyEvalTest : test::EvalSpec::EvalTest {
size_t fail_cnt = 0;
bool print_pass = false;
bool print_fail = false;
+
virtual void next_expression(const std::vector<vespalib::string> &param_names,
const vespalib::string &expression) override
{
@@ -35,6 +38,7 @@ struct MyEvalTest : test::EvalSpec::EvalTest {
++fail_cnt;
}
}
+
virtual void handle_case(const std::vector<vespalib::string> &param_names,
const std::vector<double> &param_values,
const vespalib::string &expression,
@@ -45,23 +49,34 @@ struct MyEvalTest : test::EvalSpec::EvalTest {
bool is_supported = true;
bool has_issues = InterpretedFunction::detect_issues(function);
if (is_supported && !has_issues) {
- InterpretedFunction ifun(SimpleTensorEngine::ref(), function, NodeTypes());
- ASSERT_EQUAL(ifun.num_params(), param_values.size());
- InterpretedFunction::Context ictx(ifun);
+ vespalib::string desc = as_string(param_names, param_values, expression);
InterpretedFunction::SimpleParams params(param_values);
- const Value &result_value = ifun.eval(ictx, params);
- double result = result_value.as_double();
- if (result_value.is_double() && is_same(expected_result, result)) {
- print_pass && fprintf(stderr, "verifying: %s -> %g ... PASS\n",
- as_string(param_names, param_values, expression).c_str(),
- expected_result);
- ++pass_cnt;
- } else {
- print_fail && fprintf(stderr, "verifying: %s -> %g ... FAIL: got %g\n",
- as_string(param_names, param_values, expression).c_str(),
- expected_result, result);
- ++fail_cnt;
- }
+ verify_result(SimpleTensorEngine::ref(), function, "[simple] "+desc, params, expected_result);
+ verify_result(DefaultTensorEngine::ref(), function, " [prod] "+desc, params, expected_result);
+ }
+ }
+
+ void verify_result(const TensorEngine& engine,
+ const Function &function,
+ const vespalib::string &description,
+ const InterpretedFunction::SimpleParams &params,
+ double expected_result)
+ {
+ InterpretedFunction ifun(engine, function, NodeTypes());
+ ASSERT_EQUAL(ifun.num_params(), params.params.size());
+ InterpretedFunction::Context ictx(ifun);
+ const Value &result_value = ifun.eval(ictx, params);
+ double result = result_value.as_double();
+ if (result_value.is_double() && is_same(expected_result, result)) {
+ print_pass && fprintf(stderr, "verifying: %s -> %g ... PASS\n",
+ description.c_str(),
+ expected_result);
+ ++pass_cnt;
+ } else {
+ print_fail && fprintf(stderr, "verifying: %s -> %g ... FAIL: got %g\n",
+ description.c_str(),
+ expected_result, result);
+ ++fail_cnt;
}
}
};