aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2020-09-22 08:28:02 +0000
committerHåvard Pettersen <havardpe@oath.com>2020-09-22 10:15:54 +0000
commit9a66f21d375e4fa07a96069644615581e55129d5 (patch)
tree06891b9c7f91fd0cf673cd1c540cb0dcafd5e950
parent804e8057c2eca61ec9bc8985430613e0731922a2 (diff)
handle onnx model config for inputs and outputs
-rw-r--r--searchcore/src/apps/verify_ranksetup/verify_ranksetup.cpp1
-rw-r--r--searchcore/src/tests/proton/matching/index_environment/index_environment_test.cpp25
-rw-r--r--searchcore/src/tests/proton/verify_ranksetup/verify_ranksetup_test.cpp37
-rw-r--r--searchcore/src/vespa/searchcore/proton/matching/indexenvironment.cpp9
-rw-r--r--searchcore/src/vespa/searchcore/proton/matching/indexenvironment.h2
-rw-r--r--searchcore/src/vespa/searchcore/proton/matching/onnx_models.cpp35
-rw-r--r--searchcore/src/vespa/searchcore/proton/matching/onnx_models.h15
-rw-r--r--searchcore/src/vespa/searchcore/proton/server/documentdbconfigmanager.cpp1
-rw-r--r--searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp26
-rw-r--r--searchlib/src/vespa/searchlib/features/onnx_feature.cpp29
-rw-r--r--searchlib/src/vespa/searchlib/fef/CMakeLists.txt3
-rw-r--r--searchlib/src/vespa/searchlib/fef/iindexenvironment.h6
-rw-r--r--searchlib/src/vespa/searchlib/fef/onnx_model.cpp55
-rw-r--r--searchlib/src/vespa/searchlib/fef/onnx_model.h39
-rw-r--r--searchlib/src/vespa/searchlib/fef/test/indexenvironment.cpp12
-rw-r--r--searchlib/src/vespa/searchlib/fef/test/indexenvironment.h7
-rw-r--r--streamingvisitors/src/vespa/searchvisitor/indexenvironment.h4
17 files changed, 230 insertions, 76 deletions
diff --git a/searchcore/src/apps/verify_ranksetup/verify_ranksetup.cpp b/searchcore/src/apps/verify_ranksetup/verify_ranksetup.cpp
index 2d028f47513..1d492cb558f 100644
--- a/searchcore/src/apps/verify_ranksetup/verify_ranksetup.cpp
+++ b/searchcore/src/apps/verify_ranksetup/verify_ranksetup.cpp
@@ -59,6 +59,7 @@ OnnxModels make_models(const OnnxModelsConfig &modelsCfg, const VerifyRanksetupC
for (const auto &entry: modelsCfg.model) {
if (auto file = get_file(entry.fileref, myCfg)) {
model_list.emplace_back(entry.name, file.value());
+ OnnxModels::configure(entry, model_list.back());
} else {
LOG(warning, "could not find file for onnx model '%s' (ref:'%s')\n",
entry.name.c_str(), entry.fileref.c_str());
diff --git a/searchcore/src/tests/proton/matching/index_environment/index_environment_test.cpp b/searchcore/src/tests/proton/matching/index_environment/index_environment_test.cpp
index 508a60480d0..421ebffafa4 100644
--- a/searchcore/src/tests/proton/matching/index_environment/index_environment_test.cpp
+++ b/searchcore/src/tests/proton/matching/index_environment/index_environment_test.cpp
@@ -8,6 +8,7 @@ using namespace proton::matching;
using search::fef::FieldInfo;
using search::fef::FieldType;
using search::fef::Properties;
+using search::fef::OnnxModel;
using search::index::Schema;
using search::index::schema::CollectionType;
using search::index::schema::DataType;
@@ -16,8 +17,8 @@ using SIAF = Schema::ImportedAttributeField;
OnnxModels make_models() {
OnnxModels::Vector list;
- list.emplace_back("model1", "path1");
- list.emplace_back("model2", "path2");
+ list.emplace_back(OnnxModel("model1", "path1").input_feature("input1","feature1").output_name("output1", "out1"));
+ list.emplace_back(OnnxModel("model2", "path2"));
return OnnxModels(list);
}
@@ -104,10 +105,22 @@ TEST_F("require that imported attribute fields are extracted in index environmen
EXPECT_EQUAL("[documentmetastore]", f.env.getField(2)->name());
}
-TEST_F("require that onnx model paths can be obtained", Fixture(buildEmptySchema())) {
- EXPECT_EQUAL(f1.env.getOnnxModelFullPath("model1").value(), vespalib::string("path1"));
- EXPECT_EQUAL(f1.env.getOnnxModelFullPath("model2").value(), vespalib::string("path2"));
- EXPECT_FALSE(f1.env.getOnnxModelFullPath("model3").has_value());
+TEST_F("require that onnx model config can be obtained", Fixture(buildEmptySchema())) {
+ {
+ auto model = f1.env.getOnnxModel("model1");
+ ASSERT_TRUE(model != nullptr);
+ EXPECT_EQUAL(model->file_path(), vespalib::string("path1"));
+ EXPECT_EQUAL(model->input_feature("input1").value(), vespalib::string("feature1"));
+ EXPECT_EQUAL(model->output_name("output1").value(), vespalib::string("out1"));
+ }
+ {
+ auto model = f1.env.getOnnxModel("model2");
+ ASSERT_TRUE(model != nullptr);
+ EXPECT_EQUAL(model->file_path(), vespalib::string("path2"));
+ EXPECT_FALSE(model->input_feature("input1").has_value());
+ EXPECT_FALSE(model->output_name("output1").has_value());
+ }
+ EXPECT_TRUE(f1.env.getOnnxModel("model3") == nullptr);
}
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/searchcore/src/tests/proton/verify_ranksetup/verify_ranksetup_test.cpp b/searchcore/src/tests/proton/verify_ranksetup/verify_ranksetup_test.cpp
index 1cc8d0280f6..c46990732b7 100644
--- a/searchcore/src/tests/proton/verify_ranksetup/verify_ranksetup_test.cpp
+++ b/searchcore/src/tests/proton/verify_ranksetup/verify_ranksetup_test.cpp
@@ -4,6 +4,7 @@
#include <vespa/vespalib/util/stringfmt.h>
#include <vespa/searchcommon/common/schema.h>
#include <vespa/searchlib/fef/indexproperties.h>
+#include <vespa/searchlib/fef/onnx_model.h>
#include <string>
#include <vector>
#include <map>
@@ -18,6 +19,7 @@ const char *invalid_feature = "invalid_feature_name and format";
using namespace search::fef::indexproperties;
using namespace search::index;
+using search::fef::OnnxModel;
using search::index::schema::CollectionType;
using search::index::schema::DataType;
@@ -69,9 +71,12 @@ struct Setup {
std::map<std::string,std::string> properties;
std::map<std::string,std::string> constants;
std::vector<bool> extra_profiles;
- std::map<std::string,std::string> onnx_models;
+ std::map<std::string,OnnxModel> onnx_models;
Setup();
~Setup();
+ void add_onnx_model(const OnnxModel &model) {
+ onnx_models.insert_or_assign(model.name(), model);
+ }
void index(const std::string &name, schema::DataType data_type,
schema::CollectionType collection_type)
{
@@ -155,8 +160,20 @@ struct Setup {
void write_onnx_models(const Writer &out) {
size_t idx = 0;
for (const auto &entry: onnx_models) {
- out.fmt("model[%zu].name \"%s\"\n", idx, entry.first.c_str());
+ out.fmt("model[%zu].name \"%s\"\n", idx, entry.second.name().c_str());
out.fmt("model[%zu].fileref \"onnx_ref_%zu\"\n", idx, idx);
+ size_t idx2 = 0;
+ for (const auto &input: entry.second.inspect_input_features()) {
+ out.fmt("model[%zu].input[%zu].name \"%s\"\n", idx, idx2, input.first.c_str());
+ out.fmt("model[%zu].input[%zu].source \"%s\"\n", idx, idx2, input.second.c_str());
+ ++idx2;
+ }
+ idx2 = 0;
+ for (const auto &output: entry.second.inspect_output_names()) {
+ out.fmt("model[%zu].output[%zu].name \"%s\"\n", idx, idx2, output.first.c_str());
+ out.fmt("model[%zu].output[%zu].as \"%s\"\n", idx, idx2, output.second.c_str());
+ ++idx2;
+ }
++idx;
}
}
@@ -164,7 +181,7 @@ struct Setup {
size_t idx = 0;
for (const auto &entry: onnx_models) {
out.fmt("file[%zu].ref \"onnx_ref_%zu\"\n", idx, idx);
- out.fmt("file[%zu].path \"%s\"\n", idx, entry.second.c_str());
+ out.fmt("file[%zu].path \"%s\"\n", idx, entry.second.file_path().c_str());
++idx;
}
}
@@ -225,7 +242,12 @@ struct SimpleSetup : Setup {
struct OnnxSetup : Setup {
OnnxSetup() : Setup() {
- onnx_models["simple"] = TEST_PATH("../../../../../eval/src/tests/tensor/onnx_wrapper/simple.onnx");
+ add_onnx_model(OnnxModel("simple", TEST_PATH("../../../../../eval/src/tests/tensor/onnx_wrapper/simple.onnx")));
+ add_onnx_model(OnnxModel("mapped", TEST_PATH("../../../../../eval/src/tests/tensor/onnx_wrapper/simple.onnx"))
+ .input_feature("query_tensor", "rankingExpression(qt)")
+ .input_feature("attribute_tensor", "rankingExpression(at)")
+ .input_feature("bias_tensor", "rankingExpression(bt)")
+ .output_name("output", "result"));
}
};
@@ -350,6 +372,13 @@ TEST_F("require that input type mismatch makes onnx model fail verification", On
f.verify_invalid({"onnxModel(simple)"});
}
+TEST_F("require that onnx model can have inputs and outputs mapped", OnnxSetup()) {
+ f.rank_expr("qt", "tensor<float>(a[1],b[4]):[[1,2,3,4]]");
+ f.rank_expr("at", "tensor<float>(a[4],b[1]):[[5],[6],[7],[8]]");
+ f.rank_expr("bt", "tensor<float>(a[1],b[1]):[[9]]");
+ f.verify_valid({"onnxModel(mapped).result"});
+}
+
//-----------------------------------------------------------------------------
TEST_F("cleanup files", Setup()) {
diff --git a/searchcore/src/vespa/searchcore/proton/matching/indexenvironment.cpp b/searchcore/src/vespa/searchcore/proton/matching/indexenvironment.cpp
index 5743a3d44d6..013f359c4f9 100644
--- a/searchcore/src/vespa/searchcore/proton/matching/indexenvironment.cpp
+++ b/searchcore/src/vespa/searchcore/proton/matching/indexenvironment.cpp
@@ -131,13 +131,10 @@ IndexEnvironment::hintFieldAccess(uint32_t ) const { }
void
IndexEnvironment::hintAttributeAccess(const string &) const { }
-std::optional<vespalib::string>
-IndexEnvironment::getOnnxModelFullPath(const vespalib::string &name) const
+const search::fef::OnnxModel *
+IndexEnvironment::getOnnxModel(const vespalib::string &name) const
{
- if (const auto model = _onnxModels.getModel(name)) {
- return model->filePath;
- }
- return std::nullopt;
+ return _onnxModels.getModel(name);
}
IndexEnvironment::~IndexEnvironment() = default;
diff --git a/searchcore/src/vespa/searchcore/proton/matching/indexenvironment.h b/searchcore/src/vespa/searchcore/proton/matching/indexenvironment.h
index d0e9a516cd0..ad51eb17b4d 100644
--- a/searchcore/src/vespa/searchcore/proton/matching/indexenvironment.h
+++ b/searchcore/src/vespa/searchcore/proton/matching/indexenvironment.h
@@ -69,7 +69,7 @@ public:
return _constantValueRepo.getConstant(name);
}
- std::optional<vespalib::string> getOnnxModelFullPath(const vespalib::string &name) const override;
+ const search::fef::OnnxModel *getOnnxModel(const vespalib::string &name) const override;
~IndexEnvironment() override;
};
diff --git a/searchcore/src/vespa/searchcore/proton/matching/onnx_models.cpp b/searchcore/src/vespa/searchcore/proton/matching/onnx_models.cpp
index bdcf3e21d8e..ed80ca28bd6 100644
--- a/searchcore/src/vespa/searchcore/proton/matching/onnx_models.cpp
+++ b/searchcore/src/vespa/searchcore/proton/matching/onnx_models.cpp
@@ -1,25 +1,10 @@
// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "onnx_models.h"
+#include <assert.h>
namespace proton::matching {
-OnnxModels::Model::Model(const vespalib::string &name_in,
- const vespalib::string &filePath_in)
- : name(name_in),
- filePath(filePath_in)
-{
-}
-
-OnnxModels::Model::~Model() = default;
-
-bool
-OnnxModels::Model::operator==(const Model &rhs) const
-{
- return (name == rhs.name) &&
- (filePath == rhs.filePath);
-}
-
OnnxModels::OnnxModels()
: _models()
{
@@ -30,15 +15,15 @@ OnnxModels::~OnnxModels() = default;
OnnxModels::OnnxModels(const Vector &models)
: _models()
{
- for (const auto &model : models) {
- _models.insert(std::make_pair(model.name, model));
+ for (const auto &model: models) {
+ _models.emplace(model.name(), model);
}
}
bool
OnnxModels::operator==(const OnnxModels &rhs) const
{
- return _models == rhs._models;
+ return (_models == rhs._models);
}
const OnnxModels::Model *
@@ -51,4 +36,16 @@ OnnxModels::getModel(const vespalib::string &name) const
return nullptr;
}
+void
+OnnxModels::configure(const ModelConfig &config, Model &model)
+{
+ assert(config.name == model.name());
+ for (const auto &input: config.input) {
+ model.input_feature(input.name, input.source);
+ }
+ for (const auto &output: config.output) {
+ model.output_name(output.name, output.as);
+ }
+}
+
}
diff --git a/searchcore/src/vespa/searchcore/proton/matching/onnx_models.h b/searchcore/src/vespa/searchcore/proton/matching/onnx_models.h
index fdaae657711..65ba524d8fc 100644
--- a/searchcore/src/vespa/searchcore/proton/matching/onnx_models.h
+++ b/searchcore/src/vespa/searchcore/proton/matching/onnx_models.h
@@ -3,6 +3,8 @@
#pragma once
#include <vespa/vespalib/stllike/string.h>
+#include <vespa/searchlib/fef/onnx_model.h>
+#include <vespa/searchcore/config/config-onnx-models.h>
#include <map>
#include <vector>
@@ -14,16 +16,8 @@ namespace proton::matching {
*/
class OnnxModels {
public:
- struct Model {
- vespalib::string name;
- vespalib::string filePath;
-
- Model(const vespalib::string &name_in,
- const vespalib::string &filePath_in);
- ~Model();
- bool operator==(const Model &rhs) const;
- };
-
+ using ModelConfig = vespa::config::search::core::OnnxModelsConfig::Model;
+ using Model = search::fef::OnnxModel;
using Vector = std::vector<Model>;
private:
@@ -38,6 +32,7 @@ public:
bool operator==(const OnnxModels &rhs) const;
const Model *getModel(const vespalib::string &name) const;
size_t size() const { return _models.size(); }
+ static void configure(const ModelConfig &config, Model &model);
};
}
diff --git a/searchcore/src/vespa/searchcore/proton/server/documentdbconfigmanager.cpp b/searchcore/src/vespa/searchcore/proton/server/documentdbconfigmanager.cpp
index a8996abc856..c8b701e82f8 100644
--- a/searchcore/src/vespa/searchcore/proton/server/documentdbconfigmanager.cpp
+++ b/searchcore/src/vespa/searchcore/proton/server/documentdbconfigmanager.cpp
@@ -321,6 +321,7 @@ DocumentDBConfigManager::update(const ConfigSnapshot &snapshot)
LOG(info, "Got file path from file acquirer: '%s' (name='%s', ref='%s')",
filePath.c_str(), rc.name.c_str(), rc.fileref.c_str());
models.emplace_back(rc.name, filePath);
+ OnnxModels::configure(rc, models.back());
}
}
newOnnxModels = std::make_shared<OnnxModels>(models);
diff --git a/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp
index b49d9c365de..6a1e4ef9fa1 100644
--- a/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp
+++ b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp
@@ -58,8 +58,8 @@ struct OnnxFeatureTest : ::testing::Test {
vespalib::string expr_name = feature_name + ".rankingScript";
indexEnv.getProperties().add(expr_name, expr);
}
- void add_onnx(const vespalib::string &name, const vespalib::string &file) {
- indexEnv.addOnnxModel(name, file);
+ void add_onnx(const OnnxModel &model) {
+ indexEnv.addOnnxModel(model);
}
void compile(const vespalib::string &seed) {
resolver->addSeed(seed);
@@ -89,7 +89,7 @@ TEST_F(OnnxFeatureTest, simple_onnx_model_can_be_calculated) {
add_expr("query_tensor", "tensor<float>(a[1],b[4]):[[docid,2,3,4]]");
add_expr("attribute_tensor", "tensor<float>(a[4],b[1]):[[5],[6],[7],[8]]");
add_expr("bias_tensor", "tensor<float>(a[1],b[1]):[[9]]");
- add_onnx("simple", simple_model);
+ add_onnx(OnnxModel("simple", simple_model));
compile(onnx_feature("simple"));
EXPECT_EQ(get(1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0));
EXPECT_EQ(get("onnxModel(simple).output", 1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0));
@@ -101,7 +101,7 @@ TEST_F(OnnxFeatureTest, dynamic_onnx_model_can_be_calculated) {
add_expr("query_tensor", "tensor<float>(a[1],b[4]):[[docid,2,3,4]]");
add_expr("attribute_tensor", "tensor<float>(a[4],b[1]):[[5],[6],[7],[8]]");
add_expr("bias_tensor", "tensor<float>(a[1],b[2]):[[4,5]]");
- add_onnx("dynamic", dynamic_model);
+ add_onnx(OnnxModel("dynamic", dynamic_model));
compile(onnx_feature("dynamic"));
EXPECT_EQ(get(1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0));
EXPECT_EQ(get("onnxModel(dynamic).output", 1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0));
@@ -112,7 +112,7 @@ TEST_F(OnnxFeatureTest, dynamic_onnx_model_can_be_calculated) {
TEST_F(OnnxFeatureTest, strange_input_and_output_names_are_normalized) {
add_expr("input_0", "tensor<float>(a[2]):[10,20]");
add_expr("input_1", "tensor<float>(a[2]):[5,10]");
- add_onnx("strange_names", strange_names_model);
+ add_onnx(OnnxModel("strange_names", strange_names_model));
compile(onnx_feature("strange_names"));
auto expect_add = TensorSpec("tensor<float>(d0[2])").add({{"d0",0}},15).add({{"d0",1}},30);
auto expect_sub = TensorSpec("tensor<float>(d0[2])").add({{"d0",0}},5).add({{"d0",1}},10);
@@ -121,4 +121,20 @@ TEST_F(OnnxFeatureTest, strange_input_and_output_names_are_normalized) {
EXPECT_EQ(get("onnxModel(strange_names)._baz_0", 1), expect_sub);
}
+TEST_F(OnnxFeatureTest, input_features_and_output_names_can_be_specified) {
+ add_expr("my_first_input", "tensor<float>(a[2]):[10,20]");
+ add_expr("my_second_input", "tensor<float>(a[2]):[5,10]");
+ add_onnx(OnnxModel("custom_names", strange_names_model)
+ .input_feature("input:0", "rankingExpression(my_first_input)")
+ .input_feature("input/1", "rankingExpression(my_second_input)")
+ .output_name("foo/bar", "my_first_output")
+ .output_name("-baz:0", "my_second_output"));
+ compile(onnx_feature("custom_names"));
+ auto expect_add = TensorSpec("tensor<float>(d0[2])").add({{"d0",0}},15).add({{"d0",1}},30);
+ auto expect_sub = TensorSpec("tensor<float>(d0[2])").add({{"d0",0}},5).add({{"d0",1}},10);
+ EXPECT_EQ(get(1), expect_add);
+ EXPECT_EQ(get("onnxModel(custom_names).my_first_output", 1), expect_add);
+ EXPECT_EQ(get("onnxModel(custom_names).my_second_output", 1), expect_sub);
+}
+
GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/searchlib/src/vespa/searchlib/features/onnx_feature.cpp b/searchlib/src/vespa/searchlib/features/onnx_feature.cpp
index 698d2309e5a..fca8988ba36 100644
--- a/searchlib/src/vespa/searchlib/features/onnx_feature.cpp
+++ b/searchlib/src/vespa/searchlib/features/onnx_feature.cpp
@@ -2,6 +2,7 @@
#include "onnx_feature.h"
#include <vespa/searchlib/fef/properties.h>
+#include <vespa/searchlib/fef/onnx_model.h>
#include <vespa/searchlib/fef/featureexecutor.h>
#include <vespa/eval/tensor/dense/dense_tensor_view.h>
#include <vespa/eval/tensor/dense/mutable_dense_tensor_view.h>
@@ -85,37 +86,45 @@ OnnxBlueprint::setup(const IIndexEnvironment &env,
auto optimize = (env.getFeatureMotivation() == env.FeatureMotivation::VERIFY_SETUP)
? Onnx::Optimize::DISABLE
: Onnx::Optimize::ENABLE;
- auto file_name = env.getOnnxModelFullPath(params[0].getValue());
- if (!file_name.has_value()) {
+ auto model_cfg = env.getOnnxModel(params[0].getValue());
+ if (!model_cfg) {
return fail("no model with name '%s' found", params[0].getValue().c_str());
}
try {
- _model = std::make_unique<Onnx>(file_name.value(), optimize);
+ _model = std::make_unique<Onnx>(model_cfg->file_path(), optimize);
} catch (std::exception &ex) {
return fail("model setup failed: %s", ex.what());
}
Onnx::WirePlanner planner;
for (size_t i = 0; i < _model->inputs().size(); ++i) {
const auto &model_input = _model->inputs()[i];
- vespalib::string input_name = normalize_name(model_input.name, "input");
- if (auto maybe_input = defineInput(fmt("rankingExpression(\"%s\")", input_name.c_str()), AcceptInput::OBJECT)) {
+ auto input_feature = model_cfg->input_feature(model_input.name);
+ if (!input_feature.has_value()) {
+ input_feature = fmt("rankingExpression(\"%s\")", normalize_name(model_input.name, "input").c_str());
+ }
+ if (auto maybe_input = defineInput(input_feature.value(), AcceptInput::OBJECT)) {
const FeatureType &feature_input = maybe_input.value();
assert(feature_input.is_object());
if (!planner.bind_input_type(feature_input.type(), model_input)) {
- return fail("incompatible type for input '%s': %s -> %s", input_name.c_str(),
+ return fail("incompatible type for input (%s -> %s): %s -> %s",
+ input_feature.value().c_str(), model_input.name.c_str(),
feature_input.type().to_spec().c_str(), model_input.type_as_string().c_str());
}
}
}
for (size_t i = 0; i < _model->outputs().size(); ++i) {
const auto &model_output = _model->outputs()[i];
- vespalib::string output_name = normalize_name(model_output.name, "output");
+ auto output_name = model_cfg->output_name(model_output.name);
+ if (!output_name.has_value()) {
+ output_name = normalize_name(model_output.name, "output");
+ }
ValueType output_type = planner.make_output_type(model_output);
if (output_type.is_error()) {
- return fail("unable to make compatible type for output '%s': %s -> error",
- output_name.c_str(), model_output.type_as_string().c_str());
+ return fail("unable to make compatible type for output (%s -> %s): %s -> error",
+ model_output.name.c_str(), output_name.value().c_str(),
+ model_output.type_as_string().c_str());
}
- describeOutput(output_name, "output from onnx model", FeatureType::object(output_type));
+ describeOutput(output_name.value(), "output from onnx model", FeatureType::object(output_type));
}
_wire_info = planner.get_wire_info(*_model);
return true;
diff --git a/searchlib/src/vespa/searchlib/fef/CMakeLists.txt b/searchlib/src/vespa/searchlib/fef/CMakeLists.txt
index 178de1b8b87..d6f8764cd63 100644
--- a/searchlib/src/vespa/searchlib/fef/CMakeLists.txt
+++ b/searchlib/src/vespa/searchlib/fef/CMakeLists.txt
@@ -4,12 +4,12 @@ vespa_add_library(searchlib_fef OBJECT
blueprint.cpp
blueprintfactory.cpp
blueprintresolver.cpp
+ feature_resolver.cpp
feature_type.cpp
featureexecutor.cpp
featurenamebuilder.cpp
featurenameparser.cpp
featureoverrider.cpp
- feature_resolver.cpp
fef.cpp
fieldinfo.cpp
fieldpositionsiterator.cpp
@@ -19,6 +19,7 @@ vespa_add_library(searchlib_fef OBJECT
matchdata.cpp
matchdatalayout.cpp
objectstore.cpp
+ onnx_model.cpp
parameter.cpp
parameterdescriptions.cpp
parametervalidator.cpp
diff --git a/searchlib/src/vespa/searchlib/fef/iindexenvironment.h b/searchlib/src/vespa/searchlib/fef/iindexenvironment.h
index 26e88a98033..384b81643cc 100644
--- a/searchlib/src/vespa/searchlib/fef/iindexenvironment.h
+++ b/searchlib/src/vespa/searchlib/fef/iindexenvironment.h
@@ -3,7 +3,6 @@
#pragma once
#include <vespa/vespalib/stllike/string.h>
-#include <optional>
namespace vespalib::eval { struct ConstantValue; }
@@ -12,6 +11,7 @@ namespace search::fef {
class Properties;
class FieldInfo;
class ITableManager;
+class OnnxModel;
/**
* Abstract view of index related information available to the
@@ -122,9 +122,9 @@ public:
virtual std::unique_ptr<vespalib::eval::ConstantValue> getConstantValue(const vespalib::string &name) const = 0;
/**
- * Get the full path of the file containing the given onnx model
+ * Get configuration for the given onnx model.
**/
- virtual std::optional<vespalib::string> getOnnxModelFullPath(const vespalib::string &name) const = 0;
+ virtual const OnnxModel *getOnnxModel(const vespalib::string &name) const = 0;
virtual uint32_t getDistributionKey() const = 0;
diff --git a/searchlib/src/vespa/searchlib/fef/onnx_model.cpp b/searchlib/src/vespa/searchlib/fef/onnx_model.cpp
new file mode 100644
index 00000000000..ba5adaae857
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/fef/onnx_model.cpp
@@ -0,0 +1,55 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "onnx_model.h"
+#include <tuple>
+
+namespace search::fef {
+
+OnnxModel::OnnxModel(const vespalib::string &name_in,
+ const vespalib::string &file_path_in)
+ : _name(name_in),
+ _file_path(file_path_in),
+ _input_features(),
+ _output_names()
+{
+}
+
+OnnxModel &
+OnnxModel::input_feature(const vespalib::string &model_input_name, const vespalib::string &input_feature) {
+ _input_features[model_input_name] = input_feature;
+ return *this;
+}
+
+OnnxModel &
+OnnxModel::output_name(const vespalib::string &model_output_name, const vespalib::string &output_name) {
+ _output_names[model_output_name] = output_name;
+ return *this;
+}
+
+std::optional<vespalib::string>
+OnnxModel::input_feature(const vespalib::string &model_input_name) const {
+ auto pos = _input_features.find(model_input_name);
+ if (pos != _input_features.end()) {
+ return pos->second;
+ }
+ return std::nullopt;
+}
+
+std::optional<vespalib::string>
+OnnxModel::output_name(const vespalib::string &model_output_name) const {
+ auto pos = _output_names.find(model_output_name);
+ if (pos != _output_names.end()) {
+ return pos->second;
+ }
+ return std::nullopt;
+}
+
+bool
+OnnxModel::operator==(const OnnxModel &rhs) const {
+ return (std::tie(_name, _file_path, _input_features, _output_names) ==
+ std::tie(rhs._name, rhs._file_path, rhs._input_features, rhs._output_names));
+}
+
+OnnxModel::~OnnxModel() = default;
+
+}
diff --git a/searchlib/src/vespa/searchlib/fef/onnx_model.h b/searchlib/src/vespa/searchlib/fef/onnx_model.h
new file mode 100644
index 00000000000..2195a50600d
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/fef/onnx_model.h
@@ -0,0 +1,39 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include <vespa/vespalib/stllike/string.h>
+#include <optional>
+#include <map>
+
+namespace search::fef {
+
+/**
+ * Class containing configuration for a single onnx model setup. This
+ * class is used both by the IIndexEnvironment api as well as the
+ * OnnxModels config adapter in the search core (matching component).
+ **/
+class OnnxModel {
+private:
+ vespalib::string _name;
+ vespalib::string _file_path;
+ std::map<vespalib::string,vespalib::string> _input_features;
+ std::map<vespalib::string,vespalib::string> _output_names;
+
+public:
+ OnnxModel(const vespalib::string &name_in,
+ const vespalib::string &file_path_in);
+ ~OnnxModel();
+
+ const vespalib::string &name() const { return _name; }
+ const vespalib::string &file_path() const { return _file_path; }
+ OnnxModel &input_feature(const vespalib::string &model_input_name, const vespalib::string &input_feature);
+ OnnxModel &output_name(const vespalib::string &model_output_name, const vespalib::string &output_name);
+ std::optional<vespalib::string> input_feature(const vespalib::string &model_input_name) const;
+ std::optional<vespalib::string> output_name(const vespalib::string &model_output_name) const;
+ bool operator==(const OnnxModel &rhs) const;
+ const std::map<vespalib::string,vespalib::string> &inspect_input_features() const { return _input_features; }
+ const std::map<vespalib::string,vespalib::string> &inspect_output_names() const { return _output_names; }
+};
+
+}
diff --git a/searchlib/src/vespa/searchlib/fef/test/indexenvironment.cpp b/searchlib/src/vespa/searchlib/fef/test/indexenvironment.cpp
index 6e2e0b88fbb..d2d336dcdc8 100644
--- a/searchlib/src/vespa/searchlib/fef/test/indexenvironment.cpp
+++ b/searchlib/src/vespa/searchlib/fef/test/indexenvironment.cpp
@@ -54,20 +54,20 @@ IndexEnvironment::addConstantValue(const vespalib::string &name,
(void) insertRes;
}
-std::optional<vespalib::string>
-IndexEnvironment::getOnnxModelFullPath(const vespalib::string &name) const
+const OnnxModel *
+IndexEnvironment::getOnnxModel(const vespalib::string &name) const
{
auto pos = _models.find(name);
if (pos != _models.end()) {
- return pos->second;
+ return &pos->second;
}
- return std::nullopt;
+ return nullptr;
}
void
-IndexEnvironment::addOnnxModel(const vespalib::string &name, const vespalib::string &path)
+IndexEnvironment::addOnnxModel(const OnnxModel &model)
{
- _models[name] = path;
+ _models.insert_or_assign(model.name(), model);
}
diff --git a/searchlib/src/vespa/searchlib/fef/test/indexenvironment.h b/searchlib/src/vespa/searchlib/fef/test/indexenvironment.h
index 6602d9f8ee9..0d8d0091921 100644
--- a/searchlib/src/vespa/searchlib/fef/test/indexenvironment.h
+++ b/searchlib/src/vespa/searchlib/fef/test/indexenvironment.h
@@ -5,6 +5,7 @@
#include <vespa/searchlib/attribute/attributemanager.h>
#include <vespa/searchlib/fef/iindexenvironment.h>
#include <vespa/searchlib/fef/properties.h>
+#include <vespa/searchlib/fef/onnx_model.h>
#include <vespa/searchlib/fef/fieldinfo.h>
#include <vespa/searchlib/fef/tablemanager.h>
#include <vespa/eval/eval/value_cache/constant_value.h>
@@ -47,7 +48,7 @@ public:
};
using ConstantsMap = std::map<vespalib::string, Constant>;
- using ModelMap = std::map<vespalib::string, vespalib::string>;
+ using ModelMap = std::map<vespalib::string, OnnxModel>;
IndexEnvironment();
~IndexEnvironment();
@@ -84,8 +85,8 @@ public:
vespalib::eval::ValueType type,
std::unique_ptr<vespalib::eval::Value> value);
- std::optional<vespalib::string> getOnnxModelFullPath(const vespalib::string &name) const override;
- void addOnnxModel(const vespalib::string &name, const vespalib::string &path);
+ const OnnxModel *getOnnxModel(const vespalib::string &name) const override;
+ void addOnnxModel(const OnnxModel &model);
private:
IndexEnvironment(const IndexEnvironment &); // hide
diff --git a/streamingvisitors/src/vespa/searchvisitor/indexenvironment.h b/streamingvisitors/src/vespa/searchvisitor/indexenvironment.h
index 3bbfb0b23f9..dc7be36c290 100644
--- a/streamingvisitors/src/vespa/searchvisitor/indexenvironment.h
+++ b/streamingvisitors/src/vespa/searchvisitor/indexenvironment.h
@@ -73,8 +73,8 @@ public:
return vespalib::eval::ConstantValue::UP();
}
- std::optional<vespalib::string> getOnnxModelFullPath(const vespalib::string &) const override {
- return std::nullopt;
+ const search::fef::OnnxModel *getOnnxModel(const vespalib::string &) const override {
+ return nullptr;
}
bool addField(const vespalib::string & name, bool isAttribute);