diff options
Diffstat (limited to 'searchcore/src/tests/proton/verify_ranksetup/verify_ranksetup_test.cpp')
-rw-r--r-- | searchcore/src/tests/proton/verify_ranksetup/verify_ranksetup_test.cpp | 37 |
1 files changed, 33 insertions, 4 deletions
diff --git a/searchcore/src/tests/proton/verify_ranksetup/verify_ranksetup_test.cpp b/searchcore/src/tests/proton/verify_ranksetup/verify_ranksetup_test.cpp index 1cc8d0280f6..c46990732b7 100644 --- a/searchcore/src/tests/proton/verify_ranksetup/verify_ranksetup_test.cpp +++ b/searchcore/src/tests/proton/verify_ranksetup/verify_ranksetup_test.cpp @@ -4,6 +4,7 @@ #include <vespa/vespalib/util/stringfmt.h> #include <vespa/searchcommon/common/schema.h> #include <vespa/searchlib/fef/indexproperties.h> +#include <vespa/searchlib/fef/onnx_model.h> #include <string> #include <vector> #include <map> @@ -18,6 +19,7 @@ const char *invalid_feature = "invalid_feature_name and format"; using namespace search::fef::indexproperties; using namespace search::index; +using search::fef::OnnxModel; using search::index::schema::CollectionType; using search::index::schema::DataType; @@ -69,9 +71,12 @@ struct Setup { std::map<std::string,std::string> properties; std::map<std::string,std::string> constants; std::vector<bool> extra_profiles; - std::map<std::string,std::string> onnx_models; + std::map<std::string,OnnxModel> onnx_models; Setup(); ~Setup(); + void add_onnx_model(const OnnxModel &model) { + onnx_models.insert_or_assign(model.name(), model); + } void index(const std::string &name, schema::DataType data_type, schema::CollectionType collection_type) { @@ -155,8 +160,20 @@ struct Setup { void write_onnx_models(const Writer &out) { size_t idx = 0; for (const auto &entry: onnx_models) { - out.fmt("model[%zu].name \"%s\"\n", idx, entry.first.c_str()); + out.fmt("model[%zu].name \"%s\"\n", idx, entry.second.name().c_str()); out.fmt("model[%zu].fileref \"onnx_ref_%zu\"\n", idx, idx); + size_t idx2 = 0; + for (const auto &input: entry.second.inspect_input_features()) { + out.fmt("model[%zu].input[%zu].name \"%s\"\n", idx, idx2, input.first.c_str()); + out.fmt("model[%zu].input[%zu].source \"%s\"\n", idx, idx2, input.second.c_str()); + ++idx2; + } + idx2 = 0; + for (const auto &output: entry.second.inspect_output_names()) { + out.fmt("model[%zu].output[%zu].name \"%s\"\n", idx, idx2, output.first.c_str()); + out.fmt("model[%zu].output[%zu].as \"%s\"\n", idx, idx2, output.second.c_str()); + ++idx2; + } ++idx; } } @@ -164,7 +181,7 @@ struct Setup { size_t idx = 0; for (const auto &entry: onnx_models) { out.fmt("file[%zu].ref \"onnx_ref_%zu\"\n", idx, idx); - out.fmt("file[%zu].path \"%s\"\n", idx, entry.second.c_str()); + out.fmt("file[%zu].path \"%s\"\n", idx, entry.second.file_path().c_str()); ++idx; } } @@ -225,7 +242,12 @@ struct SimpleSetup : Setup { struct OnnxSetup : Setup { OnnxSetup() : Setup() { - onnx_models["simple"] = TEST_PATH("../../../../../eval/src/tests/tensor/onnx_wrapper/simple.onnx"); + add_onnx_model(OnnxModel("simple", TEST_PATH("../../../../../eval/src/tests/tensor/onnx_wrapper/simple.onnx"))); + add_onnx_model(OnnxModel("mapped", TEST_PATH("../../../../../eval/src/tests/tensor/onnx_wrapper/simple.onnx")) + .input_feature("query_tensor", "rankingExpression(qt)") + .input_feature("attribute_tensor", "rankingExpression(at)") + .input_feature("bias_tensor", "rankingExpression(bt)") + .output_name("output", "result")); } }; @@ -350,6 +372,13 @@ TEST_F("require that input type mismatch makes onnx model fail verification", On f.verify_invalid({"onnxModel(simple)"}); } +TEST_F("require that onnx model can have inputs and outputs mapped", OnnxSetup()) { + f.rank_expr("qt", "tensor<float>(a[1],b[4]):[[1,2,3,4]]"); + f.rank_expr("at", "tensor<float>(a[4],b[1]):[[5],[6],[7],[8]]"); + f.rank_expr("bt", "tensor<float>(a[1],b[1]):[[9]]"); + f.verify_valid({"onnxModel(mapped).result"}); +} + //----------------------------------------------------------------------------- TEST_F("cleanup files", Setup()) { |