summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2018-02-06 13:23:34 +0000
committerHåvard Pettersen <havardpe@oath.com>2018-02-08 15:37:59 +0000
commit3e79340de4405f58ead08845bc07fbd50e9253ae (patch)
treea3963c011c35f7beb22ea3d5d40b65f112622f16 /eval
parent8fed2123d8bc62ceeddb0660f0c5a70719f8fff9 (diff)
add code to help testing
Diffstat (limited to 'eval')
-rw-r--r--eval/src/vespa/eval/eval/test/CMakeLists.txt1
-rw-r--r--eval/src/vespa/eval/eval/test/eval_fixture.cpp72
-rw-r--r--eval/src/vespa/eval/eval/test/eval_fixture.h84
3 files changed, 157 insertions, 0 deletions
diff --git a/eval/src/vespa/eval/eval/test/CMakeLists.txt b/eval/src/vespa/eval/eval/test/CMakeLists.txt
index f27c689e8b6..8b4f7c4f93b 100644
--- a/eval/src/vespa/eval/eval/test/CMakeLists.txt
+++ b/eval/src/vespa/eval/eval/test/CMakeLists.txt
@@ -1,6 +1,7 @@
# Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
vespa_add_library(eval_eval_test OBJECT
SOURCES
+ eval_fixture.cpp
eval_spec.cpp
tensor_conformance.cpp
test_io.cpp
diff --git a/eval/src/vespa/eval/eval/test/eval_fixture.cpp b/eval/src/vespa/eval/eval/test/eval_fixture.cpp
new file mode 100644
index 00000000000..b7e3764f6ce
--- /dev/null
+++ b/eval/src/vespa/eval/eval/test/eval_fixture.cpp
@@ -0,0 +1,72 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/vespalib/testkit/test_kit.h>
+#include "eval_fixture.h"
+
+namespace vespalib::eval::test {
+
+using ParamRepo = EvalFixture::ParamRepo;
+
+namespace {
+
+NodeTypes get_types(const Function &function, const ParamRepo &param_repo) {
+ std::vector<ValueType> param_types;
+ for (size_t i = 0; i < function.num_params(); ++i) {
+ auto pos = param_repo.map.find(function.param_name(i));
+ ASSERT_TRUE(pos != param_repo.map.end());
+ param_types.push_back(ValueType::from_spec(pos->second.type));
+ ASSERT_TRUE(!param_types.back().is_error());
+ }
+ return NodeTypes(function, param_types);
+}
+
+const TensorFunction &make_tfun(bool optimized, const TensorEngine &engine, const Function &function,
+ const NodeTypes &node_types, Stash &stash)
+{
+ const TensorFunction &plain_fun = make_tensor_function(engine, function.root(), node_types, stash);
+ return optimized ? engine.optimize(plain_fun, stash) : plain_fun;
+}
+
+std::vector<Value::UP> make_params(const TensorEngine &engine, const Function &function,
+ const ParamRepo &param_repo)
+{
+ std::vector<Value::UP> result;
+ for (size_t i = 0; i < function.num_params(); ++i) {
+ auto pos = param_repo.map.find(function.param_name(i));
+ ASSERT_TRUE(pos != param_repo.map.end());
+ result.push_back(engine.from_spec(pos->second.value));
+ ASSERT_TRUE(!result.back()->type().is_abstract());
+ }
+ return result;
+}
+
+std::vector<Value::CREF> get_refs(const std::vector<Value::UP> &values) {
+ std::vector<Value::CREF> result;
+ for (const auto &value: values) {
+ result.emplace_back(*value);
+ }
+ return result;
+}
+
+} // namespace vespalib::eval::test
+
+EvalFixture::EvalFixture(const TensorEngine &engine,
+ const vespalib::string &expr,
+ const ParamRepo &param_repo,
+ bool optimized)
+ : _engine(engine),
+ _stash(),
+ _function(Function::parse(expr)),
+ _node_types(get_types(_function, param_repo)),
+ _tensor_function(make_tfun(optimized, _engine, _function, _node_types, _stash)),
+ _ifun(_engine, _tensor_function),
+ _ictx(_ifun),
+ _param_values(make_params(_engine, _function, param_repo)),
+ _params(get_refs(_param_values)),
+ _result(_engine.to_spec(_ifun.eval(_ictx, _params)))
+{
+ auto result_type = ValueType::from_spec(_result.type());
+ ASSERT_TRUE(!result_type.is_error());
+}
+
+} // namespace vespalib::eval::test
diff --git a/eval/src/vespa/eval/eval/test/eval_fixture.h b/eval/src/vespa/eval/eval/test/eval_fixture.h
new file mode 100644
index 00000000000..1f864e980cc
--- /dev/null
+++ b/eval/src/vespa/eval/eval/test/eval_fixture.h
@@ -0,0 +1,84 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include <vespa/eval/eval/function.h>
+#include <vespa/eval/eval/tensor_spec.h>
+#include <vespa/eval/eval/tensor_function.h>
+#include <vespa/eval/eval/interpreted_function.h>
+#include <vespa/eval/eval/simple_tensor_engine.h>
+#include <vespa/eval/tensor/default_tensor_engine.h>
+#include <vespa/vespalib/util/stash.h>
+
+namespace vespalib::eval::test {
+
+class EvalFixture
+{
+public:
+ struct Param {
+ TensorSpec value; // actual parameter value
+ vespalib::string type; // pre-defined type (could be abstract)
+ Param(TensorSpec value_in)
+ : value(std::move(value_in)), type(value.type()) {}
+ Param(TensorSpec value_in, const vespalib::string &type_in)
+ : value(std::move(value_in)), type(type_in) {}
+ ~Param() {}
+ };
+
+ struct ParamRepo {
+ std::map<vespalib::string,Param> map;
+ ParamRepo() : map() {}
+ ParamRepo &add(const vespalib::string &name, TensorSpec value_in) {
+ map.insert_or_assign(name, Param(std::move(value_in)));
+ return *this;
+ }
+ ParamRepo &add(const vespalib::string &name, TensorSpec value_in, const vespalib::string &type_in) {
+ map.insert_or_assign(name, Param(std::move(value_in), type_in));
+ return *this;
+ }
+ ~ParamRepo() {}
+ };
+
+private:
+ const TensorEngine &_engine;
+ Stash _stash;
+ Function _function;
+ NodeTypes _node_types;
+ const TensorFunction &_tensor_function;
+ InterpretedFunction _ifun;
+ InterpretedFunction::Context _ictx;
+ std::vector<Value::UP> _param_values;
+ SimpleObjectParams _params;
+ TensorSpec _result;
+
+ template <typename T>
+ void find_all(const TensorFunction &node, std::vector<const T *> &list) {
+ if (auto self = as<T>(node)) {
+ list.push_back(self);
+ }
+ std::vector<TensorFunction::Child::CREF> children;
+ node.push_children(children);
+ for (const auto &child: children) {
+ find_all(child.get().get(), list);
+ }
+ }
+
+public:
+ EvalFixture(const TensorEngine &engine, const vespalib::string &expr, const ParamRepo &param_repo, bool optimized);
+ ~EvalFixture() {}
+ template <typename T>
+ std::vector<const T *> find_all() {
+ std::vector<const T *> list;
+ find_all(_tensor_function, list);
+ return list;
+ }
+ const TensorSpec &result() const { return _result; }
+ static TensorSpec ref(const vespalib::string &expr, const ParamRepo &param_repo) {
+ return EvalFixture(SimpleTensorEngine::ref(), expr, param_repo, false).result();
+ }
+ static TensorSpec prod(const vespalib::string &expr, const ParamRepo &param_repo) {
+ return EvalFixture(tensor::DefaultTensorEngine::ref(), expr, param_repo, true).result();
+ }
+};
+
+} // namespace vespalib::eval::test