diff options
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/vespa/eval/eval/test/CMakeLists.txt | 1 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/test/eval_fixture.cpp | 72 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/test/eval_fixture.h | 84 |
3 files changed, 157 insertions, 0 deletions
diff --git a/eval/src/vespa/eval/eval/test/CMakeLists.txt b/eval/src/vespa/eval/eval/test/CMakeLists.txt index f27c689e8b6..8b4f7c4f93b 100644 --- a/eval/src/vespa/eval/eval/test/CMakeLists.txt +++ b/eval/src/vespa/eval/eval/test/CMakeLists.txt @@ -1,6 +1,7 @@ # Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. vespa_add_library(eval_eval_test OBJECT SOURCES + eval_fixture.cpp eval_spec.cpp tensor_conformance.cpp test_io.cpp diff --git a/eval/src/vespa/eval/eval/test/eval_fixture.cpp b/eval/src/vespa/eval/eval/test/eval_fixture.cpp new file mode 100644 index 00000000000..b7e3764f6ce --- /dev/null +++ b/eval/src/vespa/eval/eval/test/eval_fixture.cpp @@ -0,0 +1,72 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/vespalib/testkit/test_kit.h> +#include "eval_fixture.h" + +namespace vespalib::eval::test { + +using ParamRepo = EvalFixture::ParamRepo; + +namespace { + +NodeTypes get_types(const Function &function, const ParamRepo ¶m_repo) { + std::vector<ValueType> param_types; + for (size_t i = 0; i < function.num_params(); ++i) { + auto pos = param_repo.map.find(function.param_name(i)); + ASSERT_TRUE(pos != param_repo.map.end()); + param_types.push_back(ValueType::from_spec(pos->second.type)); + ASSERT_TRUE(!param_types.back().is_error()); + } + return NodeTypes(function, param_types); +} + +const TensorFunction &make_tfun(bool optimized, const TensorEngine &engine, const Function &function, + const NodeTypes &node_types, Stash &stash) +{ + const TensorFunction &plain_fun = make_tensor_function(engine, function.root(), node_types, stash); + return optimized ? engine.optimize(plain_fun, stash) : plain_fun; +} + +std::vector<Value::UP> make_params(const TensorEngine &engine, const Function &function, + const ParamRepo ¶m_repo) +{ + std::vector<Value::UP> result; + for (size_t i = 0; i < function.num_params(); ++i) { + auto pos = param_repo.map.find(function.param_name(i)); + ASSERT_TRUE(pos != param_repo.map.end()); + result.push_back(engine.from_spec(pos->second.value)); + ASSERT_TRUE(!result.back()->type().is_abstract()); + } + return result; +} + +std::vector<Value::CREF> get_refs(const std::vector<Value::UP> &values) { + std::vector<Value::CREF> result; + for (const auto &value: values) { + result.emplace_back(*value); + } + return result; +} + +} // namespace vespalib::eval::test + +EvalFixture::EvalFixture(const TensorEngine &engine, + const vespalib::string &expr, + const ParamRepo ¶m_repo, + bool optimized) + : _engine(engine), + _stash(), + _function(Function::parse(expr)), + _node_types(get_types(_function, param_repo)), + _tensor_function(make_tfun(optimized, _engine, _function, _node_types, _stash)), + _ifun(_engine, _tensor_function), + _ictx(_ifun), + _param_values(make_params(_engine, _function, param_repo)), + _params(get_refs(_param_values)), + _result(_engine.to_spec(_ifun.eval(_ictx, _params))) +{ + auto result_type = ValueType::from_spec(_result.type()); + ASSERT_TRUE(!result_type.is_error()); +} + +} // namespace vespalib::eval::test diff --git a/eval/src/vespa/eval/eval/test/eval_fixture.h b/eval/src/vespa/eval/eval/test/eval_fixture.h new file mode 100644 index 00000000000..1f864e980cc --- /dev/null +++ b/eval/src/vespa/eval/eval/test/eval_fixture.h @@ -0,0 +1,84 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include <vespa/eval/eval/function.h> +#include <vespa/eval/eval/tensor_spec.h> +#include <vespa/eval/eval/tensor_function.h> +#include <vespa/eval/eval/interpreted_function.h> +#include <vespa/eval/eval/simple_tensor_engine.h> +#include <vespa/eval/tensor/default_tensor_engine.h> +#include <vespa/vespalib/util/stash.h> + +namespace vespalib::eval::test { + +class EvalFixture +{ +public: + struct Param { + TensorSpec value; // actual parameter value + vespalib::string type; // pre-defined type (could be abstract) + Param(TensorSpec value_in) + : value(std::move(value_in)), type(value.type()) {} + Param(TensorSpec value_in, const vespalib::string &type_in) + : value(std::move(value_in)), type(type_in) {} + ~Param() {} + }; + + struct ParamRepo { + std::map<vespalib::string,Param> map; + ParamRepo() : map() {} + ParamRepo &add(const vespalib::string &name, TensorSpec value_in) { + map.insert_or_assign(name, Param(std::move(value_in))); + return *this; + } + ParamRepo &add(const vespalib::string &name, TensorSpec value_in, const vespalib::string &type_in) { + map.insert_or_assign(name, Param(std::move(value_in), type_in)); + return *this; + } + ~ParamRepo() {} + }; + +private: + const TensorEngine &_engine; + Stash _stash; + Function _function; + NodeTypes _node_types; + const TensorFunction &_tensor_function; + InterpretedFunction _ifun; + InterpretedFunction::Context _ictx; + std::vector<Value::UP> _param_values; + SimpleObjectParams _params; + TensorSpec _result; + + template <typename T> + void find_all(const TensorFunction &node, std::vector<const T *> &list) { + if (auto self = as<T>(node)) { + list.push_back(self); + } + std::vector<TensorFunction::Child::CREF> children; + node.push_children(children); + for (const auto &child: children) { + find_all(child.get().get(), list); + } + } + +public: + EvalFixture(const TensorEngine &engine, const vespalib::string &expr, const ParamRepo ¶m_repo, bool optimized); + ~EvalFixture() {} + template <typename T> + std::vector<const T *> find_all() { + std::vector<const T *> list; + find_all(_tensor_function, list); + return list; + } + const TensorSpec &result() const { return _result; } + static TensorSpec ref(const vespalib::string &expr, const ParamRepo ¶m_repo) { + return EvalFixture(SimpleTensorEngine::ref(), expr, param_repo, false).result(); + } + static TensorSpec prod(const vespalib::string &expr, const ParamRepo ¶m_repo) { + return EvalFixture(tensor::DefaultTensorEngine::ref(), expr, param_repo, true).result(); + } +}; + +} // namespace vespalib::eval::test |