aboutsummaryrefslogtreecommitdiffstats
path: root/searchcore/src/tests/proton/verify_ranksetup
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2020-09-22 08:28:02 +0000
committerHåvard Pettersen <havardpe@oath.com>2020-09-22 10:15:54 +0000
commit9a66f21d375e4fa07a96069644615581e55129d5 (patch)
tree06891b9c7f91fd0cf673cd1c540cb0dcafd5e950 /searchcore/src/tests/proton/verify_ranksetup
parent804e8057c2eca61ec9bc8985430613e0731922a2 (diff)
handle onnx model config for inputs and outputs
Diffstat (limited to 'searchcore/src/tests/proton/verify_ranksetup')
-rw-r--r--searchcore/src/tests/proton/verify_ranksetup/verify_ranksetup_test.cpp37
1 files changed, 33 insertions, 4 deletions
diff --git a/searchcore/src/tests/proton/verify_ranksetup/verify_ranksetup_test.cpp b/searchcore/src/tests/proton/verify_ranksetup/verify_ranksetup_test.cpp
index 1cc8d0280f6..c46990732b7 100644
--- a/searchcore/src/tests/proton/verify_ranksetup/verify_ranksetup_test.cpp
+++ b/searchcore/src/tests/proton/verify_ranksetup/verify_ranksetup_test.cpp
@@ -4,6 +4,7 @@
#include <vespa/vespalib/util/stringfmt.h>
#include <vespa/searchcommon/common/schema.h>
#include <vespa/searchlib/fef/indexproperties.h>
+#include <vespa/searchlib/fef/onnx_model.h>
#include <string>
#include <vector>
#include <map>
@@ -18,6 +19,7 @@ const char *invalid_feature = "invalid_feature_name and format";
using namespace search::fef::indexproperties;
using namespace search::index;
+using search::fef::OnnxModel;
using search::index::schema::CollectionType;
using search::index::schema::DataType;
@@ -69,9 +71,12 @@ struct Setup {
std::map<std::string,std::string> properties;
std::map<std::string,std::string> constants;
std::vector<bool> extra_profiles;
- std::map<std::string,std::string> onnx_models;
+ std::map<std::string,OnnxModel> onnx_models;
Setup();
~Setup();
+ void add_onnx_model(const OnnxModel &model) {
+ onnx_models.insert_or_assign(model.name(), model);
+ }
void index(const std::string &name, schema::DataType data_type,
schema::CollectionType collection_type)
{
@@ -155,8 +160,20 @@ struct Setup {
void write_onnx_models(const Writer &out) {
size_t idx = 0;
for (const auto &entry: onnx_models) {
- out.fmt("model[%zu].name \"%s\"\n", idx, entry.first.c_str());
+ out.fmt("model[%zu].name \"%s\"\n", idx, entry.second.name().c_str());
out.fmt("model[%zu].fileref \"onnx_ref_%zu\"\n", idx, idx);
+ size_t idx2 = 0;
+ for (const auto &input: entry.second.inspect_input_features()) {
+ out.fmt("model[%zu].input[%zu].name \"%s\"\n", idx, idx2, input.first.c_str());
+ out.fmt("model[%zu].input[%zu].source \"%s\"\n", idx, idx2, input.second.c_str());
+ ++idx2;
+ }
+ idx2 = 0;
+ for (const auto &output: entry.second.inspect_output_names()) {
+ out.fmt("model[%zu].output[%zu].name \"%s\"\n", idx, idx2, output.first.c_str());
+ out.fmt("model[%zu].output[%zu].as \"%s\"\n", idx, idx2, output.second.c_str());
+ ++idx2;
+ }
++idx;
}
}
@@ -164,7 +181,7 @@ struct Setup {
size_t idx = 0;
for (const auto &entry: onnx_models) {
out.fmt("file[%zu].ref \"onnx_ref_%zu\"\n", idx, idx);
- out.fmt("file[%zu].path \"%s\"\n", idx, entry.second.c_str());
+ out.fmt("file[%zu].path \"%s\"\n", idx, entry.second.file_path().c_str());
++idx;
}
}
@@ -225,7 +242,12 @@ struct SimpleSetup : Setup {
struct OnnxSetup : Setup {
OnnxSetup() : Setup() {
- onnx_models["simple"] = TEST_PATH("../../../../../eval/src/tests/tensor/onnx_wrapper/simple.onnx");
+ add_onnx_model(OnnxModel("simple", TEST_PATH("../../../../../eval/src/tests/tensor/onnx_wrapper/simple.onnx")));
+ add_onnx_model(OnnxModel("mapped", TEST_PATH("../../../../../eval/src/tests/tensor/onnx_wrapper/simple.onnx"))
+ .input_feature("query_tensor", "rankingExpression(qt)")
+ .input_feature("attribute_tensor", "rankingExpression(at)")
+ .input_feature("bias_tensor", "rankingExpression(bt)")
+ .output_name("output", "result"));
}
};
@@ -350,6 +372,13 @@ TEST_F("require that input type mismatch makes onnx model fail verification", On
f.verify_invalid({"onnxModel(simple)"});
}
+TEST_F("require that onnx model can have inputs and outputs mapped", OnnxSetup()) {
+ f.rank_expr("qt", "tensor<float>(a[1],b[4]):[[1,2,3,4]]");
+ f.rank_expr("at", "tensor<float>(a[4],b[1]):[[5],[6],[7],[8]]");
+ f.rank_expr("bt", "tensor<float>(a[1],b[1]):[[9]]");
+ f.verify_valid({"onnxModel(mapped).result"});
+}
+
//-----------------------------------------------------------------------------
TEST_F("cleanup files", Setup()) {