summaryrefslogtreecommitdiffstats
path: root/eval/src/apps
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2018-01-18 13:31:57 +0000
committerHåvard Pettersen <havardpe@oath.com>2018-01-18 13:31:57 +0000
commit04ed2c4337ed3398bb110e9a845efcd2eb954577 (patch)
tree8c3b1779721a1ff61dd3e5255fce3159ad80bd45 /eval/src/apps
parent15b416446c7a6e674a14b5de0e0243fdab83f340 (diff)
run cross-language tensor conformance tests using tensor functions
also pass tensor engine to tensor function eval function
Diffstat (limited to 'eval/src/apps')
-rw-r--r--eval/src/apps/tensor_conformance/tensor_conformance.cpp17
1 files changed, 17 insertions, 0 deletions
diff --git a/eval/src/apps/tensor_conformance/tensor_conformance.cpp b/eval/src/apps/tensor_conformance/tensor_conformance.cpp
index 72fe61d7107..33c303f9574 100644
--- a/eval/src/apps/tensor_conformance/tensor_conformance.cpp
+++ b/eval/src/apps/tensor_conformance/tensor_conformance.cpp
@@ -98,6 +98,21 @@ TensorSpec eval_expr(const Inspector &test, const TensorEngine &engine, bool typ
return engine.to_spec(ifun.eval(ctx, params));
}
+TensorSpec eval_expr_tf(const Inspector &test, const TensorEngine &engine) {
+ Stash stash;
+ Function fun = Function::parse(test["expression"].asString().make_string());
+ std::vector<Value::UP> param_values;
+ std::vector<Value::CREF> param_refs;
+ for (size_t i = 0; i < fun.num_params(); ++i) {
+ param_values.emplace_back(engine.from_spec(extract_value(test["inputs"][fun.param_name(i)])));
+ param_refs.emplace_back(*param_values.back());
+ }
+ SimpleObjectParams params(param_refs);
+ NodeTypes types = NodeTypes(fun, get_types(param_values));
+ const auto &tfun = make_tensor_function(engine, fun.root(), types, stash);
+ return engine.to_spec(tfun.eval(engine, params, stash));
+}
+
//-----------------------------------------------------------------------------
std::vector<vespalib::string> extract_fields(const Inspector &object) {
@@ -164,6 +179,8 @@ void evaluate(Input &in, Output &out) {
eval_expr(slime.get(), DefaultTensorEngine::ref(), false));
insert_value(slime["result"], "cpp_ref_typed",
eval_expr(slime.get(), SimpleTensorEngine::ref(), true));
+ insert_value(slime["result"], "cpp_tensor_function",
+ eval_expr_tf(slime.get(), DefaultTensorEngine::ref()));
write_compact(slime, out);
};
auto handle_summary = [&out](Slime &slime)