diff options
author | Håvard Pettersen <havardpe@oath.com> | 2020-09-22 08:28:02 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2020-09-22 10:15:54 +0000 |
commit | 9a66f21d375e4fa07a96069644615581e55129d5 (patch) | |
tree | 06891b9c7f91fd0cf673cd1c540cb0dcafd5e950 | |
parent | 804e8057c2eca61ec9bc8985430613e0731922a2 (diff) |
handle onnx model config for inputs and outputs
17 files changed, 230 insertions, 76 deletions
diff --git a/searchcore/src/apps/verify_ranksetup/verify_ranksetup.cpp b/searchcore/src/apps/verify_ranksetup/verify_ranksetup.cpp index 2d028f47513..1d492cb558f 100644 --- a/searchcore/src/apps/verify_ranksetup/verify_ranksetup.cpp +++ b/searchcore/src/apps/verify_ranksetup/verify_ranksetup.cpp @@ -59,6 +59,7 @@ OnnxModels make_models(const OnnxModelsConfig &modelsCfg, const VerifyRanksetupC for (const auto &entry: modelsCfg.model) { if (auto file = get_file(entry.fileref, myCfg)) { model_list.emplace_back(entry.name, file.value()); + OnnxModels::configure(entry, model_list.back()); } else { LOG(warning, "could not find file for onnx model '%s' (ref:'%s')\n", entry.name.c_str(), entry.fileref.c_str()); diff --git a/searchcore/src/tests/proton/matching/index_environment/index_environment_test.cpp b/searchcore/src/tests/proton/matching/index_environment/index_environment_test.cpp index 508a60480d0..421ebffafa4 100644 --- a/searchcore/src/tests/proton/matching/index_environment/index_environment_test.cpp +++ b/searchcore/src/tests/proton/matching/index_environment/index_environment_test.cpp @@ -8,6 +8,7 @@ using namespace proton::matching; using search::fef::FieldInfo; using search::fef::FieldType; using search::fef::Properties; +using search::fef::OnnxModel; using search::index::Schema; using search::index::schema::CollectionType; using search::index::schema::DataType; @@ -16,8 +17,8 @@ using SIAF = Schema::ImportedAttributeField; OnnxModels make_models() { OnnxModels::Vector list; - list.emplace_back("model1", "path1"); - list.emplace_back("model2", "path2"); + list.emplace_back(OnnxModel("model1", "path1").input_feature("input1","feature1").output_name("output1", "out1")); + list.emplace_back(OnnxModel("model2", "path2")); return OnnxModels(list); } @@ -104,10 +105,22 @@ TEST_F("require that imported attribute fields are extracted in index environmen EXPECT_EQUAL("[documentmetastore]", f.env.getField(2)->name()); } -TEST_F("require that onnx model paths can be obtained", Fixture(buildEmptySchema())) { - EXPECT_EQUAL(f1.env.getOnnxModelFullPath("model1").value(), vespalib::string("path1")); - EXPECT_EQUAL(f1.env.getOnnxModelFullPath("model2").value(), vespalib::string("path2")); - EXPECT_FALSE(f1.env.getOnnxModelFullPath("model3").has_value()); +TEST_F("require that onnx model config can be obtained", Fixture(buildEmptySchema())) { + { + auto model = f1.env.getOnnxModel("model1"); + ASSERT_TRUE(model != nullptr); + EXPECT_EQUAL(model->file_path(), vespalib::string("path1")); + EXPECT_EQUAL(model->input_feature("input1").value(), vespalib::string("feature1")); + EXPECT_EQUAL(model->output_name("output1").value(), vespalib::string("out1")); + } + { + auto model = f1.env.getOnnxModel("model2"); + ASSERT_TRUE(model != nullptr); + EXPECT_EQUAL(model->file_path(), vespalib::string("path2")); + EXPECT_FALSE(model->input_feature("input1").has_value()); + EXPECT_FALSE(model->output_name("output1").has_value()); + } + EXPECT_TRUE(f1.env.getOnnxModel("model3") == nullptr); } TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/searchcore/src/tests/proton/verify_ranksetup/verify_ranksetup_test.cpp b/searchcore/src/tests/proton/verify_ranksetup/verify_ranksetup_test.cpp index 1cc8d0280f6..c46990732b7 100644 --- a/searchcore/src/tests/proton/verify_ranksetup/verify_ranksetup_test.cpp +++ b/searchcore/src/tests/proton/verify_ranksetup/verify_ranksetup_test.cpp @@ -4,6 +4,7 @@ #include <vespa/vespalib/util/stringfmt.h> #include <vespa/searchcommon/common/schema.h> #include <vespa/searchlib/fef/indexproperties.h> +#include <vespa/searchlib/fef/onnx_model.h> #include <string> #include <vector> #include <map> @@ -18,6 +19,7 @@ const char *invalid_feature = "invalid_feature_name and format"; using namespace search::fef::indexproperties; using namespace search::index; +using search::fef::OnnxModel; using search::index::schema::CollectionType; using search::index::schema::DataType; @@ -69,9 +71,12 @@ struct Setup { std::map<std::string,std::string> properties; std::map<std::string,std::string> constants; std::vector<bool> extra_profiles; - std::map<std::string,std::string> onnx_models; + std::map<std::string,OnnxModel> onnx_models; Setup(); ~Setup(); + void add_onnx_model(const OnnxModel &model) { + onnx_models.insert_or_assign(model.name(), model); + } void index(const std::string &name, schema::DataType data_type, schema::CollectionType collection_type) { @@ -155,8 +160,20 @@ struct Setup { void write_onnx_models(const Writer &out) { size_t idx = 0; for (const auto &entry: onnx_models) { - out.fmt("model[%zu].name \"%s\"\n", idx, entry.first.c_str()); + out.fmt("model[%zu].name \"%s\"\n", idx, entry.second.name().c_str()); out.fmt("model[%zu].fileref \"onnx_ref_%zu\"\n", idx, idx); + size_t idx2 = 0; + for (const auto &input: entry.second.inspect_input_features()) { + out.fmt("model[%zu].input[%zu].name \"%s\"\n", idx, idx2, input.first.c_str()); + out.fmt("model[%zu].input[%zu].source \"%s\"\n", idx, idx2, input.second.c_str()); + ++idx2; + } + idx2 = 0; + for (const auto &output: entry.second.inspect_output_names()) { + out.fmt("model[%zu].output[%zu].name \"%s\"\n", idx, idx2, output.first.c_str()); + out.fmt("model[%zu].output[%zu].as \"%s\"\n", idx, idx2, output.second.c_str()); + ++idx2; + } ++idx; } } @@ -164,7 +181,7 @@ struct Setup { size_t idx = 0; for (const auto &entry: onnx_models) { out.fmt("file[%zu].ref \"onnx_ref_%zu\"\n", idx, idx); - out.fmt("file[%zu].path \"%s\"\n", idx, entry.second.c_str()); + out.fmt("file[%zu].path \"%s\"\n", idx, entry.second.file_path().c_str()); ++idx; } } @@ -225,7 +242,12 @@ struct SimpleSetup : Setup { struct OnnxSetup : Setup { OnnxSetup() : Setup() { - onnx_models["simple"] = TEST_PATH("../../../../../eval/src/tests/tensor/onnx_wrapper/simple.onnx"); + add_onnx_model(OnnxModel("simple", TEST_PATH("../../../../../eval/src/tests/tensor/onnx_wrapper/simple.onnx"))); + add_onnx_model(OnnxModel("mapped", TEST_PATH("../../../../../eval/src/tests/tensor/onnx_wrapper/simple.onnx")) + .input_feature("query_tensor", "rankingExpression(qt)") + .input_feature("attribute_tensor", "rankingExpression(at)") + .input_feature("bias_tensor", "rankingExpression(bt)") + .output_name("output", "result")); } }; @@ -350,6 +372,13 @@ TEST_F("require that input type mismatch makes onnx model fail verification", On f.verify_invalid({"onnxModel(simple)"}); } +TEST_F("require that onnx model can have inputs and outputs mapped", OnnxSetup()) { + f.rank_expr("qt", "tensor<float>(a[1],b[4]):[[1,2,3,4]]"); + f.rank_expr("at", "tensor<float>(a[4],b[1]):[[5],[6],[7],[8]]"); + f.rank_expr("bt", "tensor<float>(a[1],b[1]):[[9]]"); + f.verify_valid({"onnxModel(mapped).result"}); +} + //----------------------------------------------------------------------------- TEST_F("cleanup files", Setup()) { diff --git a/searchcore/src/vespa/searchcore/proton/matching/indexenvironment.cpp b/searchcore/src/vespa/searchcore/proton/matching/indexenvironment.cpp index 5743a3d44d6..013f359c4f9 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/indexenvironment.cpp +++ b/searchcore/src/vespa/searchcore/proton/matching/indexenvironment.cpp @@ -131,13 +131,10 @@ IndexEnvironment::hintFieldAccess(uint32_t ) const { } void IndexEnvironment::hintAttributeAccess(const string &) const { } -std::optional<vespalib::string> -IndexEnvironment::getOnnxModelFullPath(const vespalib::string &name) const +const search::fef::OnnxModel * +IndexEnvironment::getOnnxModel(const vespalib::string &name) const { - if (const auto model = _onnxModels.getModel(name)) { - return model->filePath; - } - return std::nullopt; + return _onnxModels.getModel(name); } IndexEnvironment::~IndexEnvironment() = default; diff --git a/searchcore/src/vespa/searchcore/proton/matching/indexenvironment.h b/searchcore/src/vespa/searchcore/proton/matching/indexenvironment.h index d0e9a516cd0..ad51eb17b4d 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/indexenvironment.h +++ b/searchcore/src/vespa/searchcore/proton/matching/indexenvironment.h @@ -69,7 +69,7 @@ public: return _constantValueRepo.getConstant(name); } - std::optional<vespalib::string> getOnnxModelFullPath(const vespalib::string &name) const override; + const search::fef::OnnxModel *getOnnxModel(const vespalib::string &name) const override; ~IndexEnvironment() override; }; diff --git a/searchcore/src/vespa/searchcore/proton/matching/onnx_models.cpp b/searchcore/src/vespa/searchcore/proton/matching/onnx_models.cpp index bdcf3e21d8e..ed80ca28bd6 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/onnx_models.cpp +++ b/searchcore/src/vespa/searchcore/proton/matching/onnx_models.cpp @@ -1,25 +1,10 @@ // Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "onnx_models.h" +#include <assert.h> namespace proton::matching { -OnnxModels::Model::Model(const vespalib::string &name_in, - const vespalib::string &filePath_in) - : name(name_in), - filePath(filePath_in) -{ -} - -OnnxModels::Model::~Model() = default; - -bool -OnnxModels::Model::operator==(const Model &rhs) const -{ - return (name == rhs.name) && - (filePath == rhs.filePath); -} - OnnxModels::OnnxModels() : _models() { @@ -30,15 +15,15 @@ OnnxModels::~OnnxModels() = default; OnnxModels::OnnxModels(const Vector &models) : _models() { - for (const auto &model : models) { - _models.insert(std::make_pair(model.name, model)); + for (const auto &model: models) { + _models.emplace(model.name(), model); } } bool OnnxModels::operator==(const OnnxModels &rhs) const { - return _models == rhs._models; + return (_models == rhs._models); } const OnnxModels::Model * @@ -51,4 +36,16 @@ OnnxModels::getModel(const vespalib::string &name) const return nullptr; } +void +OnnxModels::configure(const ModelConfig &config, Model &model) +{ + assert(config.name == model.name()); + for (const auto &input: config.input) { + model.input_feature(input.name, input.source); + } + for (const auto &output: config.output) { + model.output_name(output.name, output.as); + } +} + } diff --git a/searchcore/src/vespa/searchcore/proton/matching/onnx_models.h b/searchcore/src/vespa/searchcore/proton/matching/onnx_models.h index fdaae657711..65ba524d8fc 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/onnx_models.h +++ b/searchcore/src/vespa/searchcore/proton/matching/onnx_models.h @@ -3,6 +3,8 @@ #pragma once #include <vespa/vespalib/stllike/string.h> +#include <vespa/searchlib/fef/onnx_model.h> +#include <vespa/searchcore/config/config-onnx-models.h> #include <map> #include <vector> @@ -14,16 +16,8 @@ namespace proton::matching { */ class OnnxModels { public: - struct Model { - vespalib::string name; - vespalib::string filePath; - - Model(const vespalib::string &name_in, - const vespalib::string &filePath_in); - ~Model(); - bool operator==(const Model &rhs) const; - }; - + using ModelConfig = vespa::config::search::core::OnnxModelsConfig::Model; + using Model = search::fef::OnnxModel; using Vector = std::vector<Model>; private: @@ -38,6 +32,7 @@ public: bool operator==(const OnnxModels &rhs) const; const Model *getModel(const vespalib::string &name) const; size_t size() const { return _models.size(); } + static void configure(const ModelConfig &config, Model &model); }; } diff --git a/searchcore/src/vespa/searchcore/proton/server/documentdbconfigmanager.cpp b/searchcore/src/vespa/searchcore/proton/server/documentdbconfigmanager.cpp index a8996abc856..c8b701e82f8 100644 --- a/searchcore/src/vespa/searchcore/proton/server/documentdbconfigmanager.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/documentdbconfigmanager.cpp @@ -321,6 +321,7 @@ DocumentDBConfigManager::update(const ConfigSnapshot &snapshot) LOG(info, "Got file path from file acquirer: '%s' (name='%s', ref='%s')", filePath.c_str(), rc.name.c_str(), rc.fileref.c_str()); models.emplace_back(rc.name, filePath); + OnnxModels::configure(rc, models.back()); } } newOnnxModels = std::make_shared<OnnxModels>(models); diff --git a/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp index b49d9c365de..6a1e4ef9fa1 100644 --- a/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp +++ b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp @@ -58,8 +58,8 @@ struct OnnxFeatureTest : ::testing::Test { vespalib::string expr_name = feature_name + ".rankingScript"; indexEnv.getProperties().add(expr_name, expr); } - void add_onnx(const vespalib::string &name, const vespalib::string &file) { - indexEnv.addOnnxModel(name, file); + void add_onnx(const OnnxModel &model) { + indexEnv.addOnnxModel(model); } void compile(const vespalib::string &seed) { resolver->addSeed(seed); @@ -89,7 +89,7 @@ TEST_F(OnnxFeatureTest, simple_onnx_model_can_be_calculated) { add_expr("query_tensor", "tensor<float>(a[1],b[4]):[[docid,2,3,4]]"); add_expr("attribute_tensor", "tensor<float>(a[4],b[1]):[[5],[6],[7],[8]]"); add_expr("bias_tensor", "tensor<float>(a[1],b[1]):[[9]]"); - add_onnx("simple", simple_model); + add_onnx(OnnxModel("simple", simple_model)); compile(onnx_feature("simple")); EXPECT_EQ(get(1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0)); EXPECT_EQ(get("onnxModel(simple).output", 1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0)); @@ -101,7 +101,7 @@ TEST_F(OnnxFeatureTest, dynamic_onnx_model_can_be_calculated) { add_expr("query_tensor", "tensor<float>(a[1],b[4]):[[docid,2,3,4]]"); add_expr("attribute_tensor", "tensor<float>(a[4],b[1]):[[5],[6],[7],[8]]"); add_expr("bias_tensor", "tensor<float>(a[1],b[2]):[[4,5]]"); - add_onnx("dynamic", dynamic_model); + add_onnx(OnnxModel("dynamic", dynamic_model)); compile(onnx_feature("dynamic")); EXPECT_EQ(get(1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0)); EXPECT_EQ(get("onnxModel(dynamic).output", 1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0)); @@ -112,7 +112,7 @@ TEST_F(OnnxFeatureTest, dynamic_onnx_model_can_be_calculated) { TEST_F(OnnxFeatureTest, strange_input_and_output_names_are_normalized) { add_expr("input_0", "tensor<float>(a[2]):[10,20]"); add_expr("input_1", "tensor<float>(a[2]):[5,10]"); - add_onnx("strange_names", strange_names_model); + add_onnx(OnnxModel("strange_names", strange_names_model)); compile(onnx_feature("strange_names")); auto expect_add = TensorSpec("tensor<float>(d0[2])").add({{"d0",0}},15).add({{"d0",1}},30); auto expect_sub = TensorSpec("tensor<float>(d0[2])").add({{"d0",0}},5).add({{"d0",1}},10); @@ -121,4 +121,20 @@ TEST_F(OnnxFeatureTest, strange_input_and_output_names_are_normalized) { EXPECT_EQ(get("onnxModel(strange_names)._baz_0", 1), expect_sub); } +TEST_F(OnnxFeatureTest, input_features_and_output_names_can_be_specified) { + add_expr("my_first_input", "tensor<float>(a[2]):[10,20]"); + add_expr("my_second_input", "tensor<float>(a[2]):[5,10]"); + add_onnx(OnnxModel("custom_names", strange_names_model) + .input_feature("input:0", "rankingExpression(my_first_input)") + .input_feature("input/1", "rankingExpression(my_second_input)") + .output_name("foo/bar", "my_first_output") + .output_name("-baz:0", "my_second_output")); + compile(onnx_feature("custom_names")); + auto expect_add = TensorSpec("tensor<float>(d0[2])").add({{"d0",0}},15).add({{"d0",1}},30); + auto expect_sub = TensorSpec("tensor<float>(d0[2])").add({{"d0",0}},5).add({{"d0",1}},10); + EXPECT_EQ(get(1), expect_add); + EXPECT_EQ(get("onnxModel(custom_names).my_first_output", 1), expect_add); + EXPECT_EQ(get("onnxModel(custom_names).my_second_output", 1), expect_sub); +} + GTEST_MAIN_RUN_ALL_TESTS() diff --git a/searchlib/src/vespa/searchlib/features/onnx_feature.cpp b/searchlib/src/vespa/searchlib/features/onnx_feature.cpp index 698d2309e5a..fca8988ba36 100644 --- a/searchlib/src/vespa/searchlib/features/onnx_feature.cpp +++ b/searchlib/src/vespa/searchlib/features/onnx_feature.cpp @@ -2,6 +2,7 @@ #include "onnx_feature.h" #include <vespa/searchlib/fef/properties.h> +#include <vespa/searchlib/fef/onnx_model.h> #include <vespa/searchlib/fef/featureexecutor.h> #include <vespa/eval/tensor/dense/dense_tensor_view.h> #include <vespa/eval/tensor/dense/mutable_dense_tensor_view.h> @@ -85,37 +86,45 @@ OnnxBlueprint::setup(const IIndexEnvironment &env, auto optimize = (env.getFeatureMotivation() == env.FeatureMotivation::VERIFY_SETUP) ? Onnx::Optimize::DISABLE : Onnx::Optimize::ENABLE; - auto file_name = env.getOnnxModelFullPath(params[0].getValue()); - if (!file_name.has_value()) { + auto model_cfg = env.getOnnxModel(params[0].getValue()); + if (!model_cfg) { return fail("no model with name '%s' found", params[0].getValue().c_str()); } try { - _model = std::make_unique<Onnx>(file_name.value(), optimize); + _model = std::make_unique<Onnx>(model_cfg->file_path(), optimize); } catch (std::exception &ex) { return fail("model setup failed: %s", ex.what()); } Onnx::WirePlanner planner; for (size_t i = 0; i < _model->inputs().size(); ++i) { const auto &model_input = _model->inputs()[i]; - vespalib::string input_name = normalize_name(model_input.name, "input"); - if (auto maybe_input = defineInput(fmt("rankingExpression(\"%s\")", input_name.c_str()), AcceptInput::OBJECT)) { + auto input_feature = model_cfg->input_feature(model_input.name); + if (!input_feature.has_value()) { + input_feature = fmt("rankingExpression(\"%s\")", normalize_name(model_input.name, "input").c_str()); + } + if (auto maybe_input = defineInput(input_feature.value(), AcceptInput::OBJECT)) { const FeatureType &feature_input = maybe_input.value(); assert(feature_input.is_object()); if (!planner.bind_input_type(feature_input.type(), model_input)) { - return fail("incompatible type for input '%s': %s -> %s", input_name.c_str(), + return fail("incompatible type for input (%s -> %s): %s -> %s", + input_feature.value().c_str(), model_input.name.c_str(), feature_input.type().to_spec().c_str(), model_input.type_as_string().c_str()); } } } for (size_t i = 0; i < _model->outputs().size(); ++i) { const auto &model_output = _model->outputs()[i]; - vespalib::string output_name = normalize_name(model_output.name, "output"); + auto output_name = model_cfg->output_name(model_output.name); + if (!output_name.has_value()) { + output_name = normalize_name(model_output.name, "output"); + } ValueType output_type = planner.make_output_type(model_output); if (output_type.is_error()) { - return fail("unable to make compatible type for output '%s': %s -> error", - output_name.c_str(), model_output.type_as_string().c_str()); + return fail("unable to make compatible type for output (%s -> %s): %s -> error", + model_output.name.c_str(), output_name.value().c_str(), + model_output.type_as_string().c_str()); } - describeOutput(output_name, "output from onnx model", FeatureType::object(output_type)); + describeOutput(output_name.value(), "output from onnx model", FeatureType::object(output_type)); } _wire_info = planner.get_wire_info(*_model); return true; diff --git a/searchlib/src/vespa/searchlib/fef/CMakeLists.txt b/searchlib/src/vespa/searchlib/fef/CMakeLists.txt index 178de1b8b87..d6f8764cd63 100644 --- a/searchlib/src/vespa/searchlib/fef/CMakeLists.txt +++ b/searchlib/src/vespa/searchlib/fef/CMakeLists.txt @@ -4,12 +4,12 @@ vespa_add_library(searchlib_fef OBJECT blueprint.cpp blueprintfactory.cpp blueprintresolver.cpp + feature_resolver.cpp feature_type.cpp featureexecutor.cpp featurenamebuilder.cpp featurenameparser.cpp featureoverrider.cpp - feature_resolver.cpp fef.cpp fieldinfo.cpp fieldpositionsiterator.cpp @@ -19,6 +19,7 @@ vespa_add_library(searchlib_fef OBJECT matchdata.cpp matchdatalayout.cpp objectstore.cpp + onnx_model.cpp parameter.cpp parameterdescriptions.cpp parametervalidator.cpp diff --git a/searchlib/src/vespa/searchlib/fef/iindexenvironment.h b/searchlib/src/vespa/searchlib/fef/iindexenvironment.h index 26e88a98033..384b81643cc 100644 --- a/searchlib/src/vespa/searchlib/fef/iindexenvironment.h +++ b/searchlib/src/vespa/searchlib/fef/iindexenvironment.h @@ -3,7 +3,6 @@ #pragma once #include <vespa/vespalib/stllike/string.h> -#include <optional> namespace vespalib::eval { struct ConstantValue; } @@ -12,6 +11,7 @@ namespace search::fef { class Properties; class FieldInfo; class ITableManager; +class OnnxModel; /** * Abstract view of index related information available to the @@ -122,9 +122,9 @@ public: virtual std::unique_ptr<vespalib::eval::ConstantValue> getConstantValue(const vespalib::string &name) const = 0; /** - * Get the full path of the file containing the given onnx model + * Get configuration for the given onnx model. **/ - virtual std::optional<vespalib::string> getOnnxModelFullPath(const vespalib::string &name) const = 0; + virtual const OnnxModel *getOnnxModel(const vespalib::string &name) const = 0; virtual uint32_t getDistributionKey() const = 0; diff --git a/searchlib/src/vespa/searchlib/fef/onnx_model.cpp b/searchlib/src/vespa/searchlib/fef/onnx_model.cpp new file mode 100644 index 00000000000..ba5adaae857 --- /dev/null +++ b/searchlib/src/vespa/searchlib/fef/onnx_model.cpp @@ -0,0 +1,55 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "onnx_model.h" +#include <tuple> + +namespace search::fef { + +OnnxModel::OnnxModel(const vespalib::string &name_in, + const vespalib::string &file_path_in) + : _name(name_in), + _file_path(file_path_in), + _input_features(), + _output_names() +{ +} + +OnnxModel & +OnnxModel::input_feature(const vespalib::string &model_input_name, const vespalib::string &input_feature) { + _input_features[model_input_name] = input_feature; + return *this; +} + +OnnxModel & +OnnxModel::output_name(const vespalib::string &model_output_name, const vespalib::string &output_name) { + _output_names[model_output_name] = output_name; + return *this; +} + +std::optional<vespalib::string> +OnnxModel::input_feature(const vespalib::string &model_input_name) const { + auto pos = _input_features.find(model_input_name); + if (pos != _input_features.end()) { + return pos->second; + } + return std::nullopt; +} + +std::optional<vespalib::string> +OnnxModel::output_name(const vespalib::string &model_output_name) const { + auto pos = _output_names.find(model_output_name); + if (pos != _output_names.end()) { + return pos->second; + } + return std::nullopt; +} + +bool +OnnxModel::operator==(const OnnxModel &rhs) const { + return (std::tie(_name, _file_path, _input_features, _output_names) == + std::tie(rhs._name, rhs._file_path, rhs._input_features, rhs._output_names)); +} + +OnnxModel::~OnnxModel() = default; + +} diff --git a/searchlib/src/vespa/searchlib/fef/onnx_model.h b/searchlib/src/vespa/searchlib/fef/onnx_model.h new file mode 100644 index 00000000000..2195a50600d --- /dev/null +++ b/searchlib/src/vespa/searchlib/fef/onnx_model.h @@ -0,0 +1,39 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include <vespa/vespalib/stllike/string.h> +#include <optional> +#include <map> + +namespace search::fef { + +/** + * Class containing configuration for a single onnx model setup. This + * class is used both by the IIndexEnvironment api as well as the + * OnnxModels config adapter in the search core (matching component). + **/ +class OnnxModel { +private: + vespalib::string _name; + vespalib::string _file_path; + std::map<vespalib::string,vespalib::string> _input_features; + std::map<vespalib::string,vespalib::string> _output_names; + +public: + OnnxModel(const vespalib::string &name_in, + const vespalib::string &file_path_in); + ~OnnxModel(); + + const vespalib::string &name() const { return _name; } + const vespalib::string &file_path() const { return _file_path; } + OnnxModel &input_feature(const vespalib::string &model_input_name, const vespalib::string &input_feature); + OnnxModel &output_name(const vespalib::string &model_output_name, const vespalib::string &output_name); + std::optional<vespalib::string> input_feature(const vespalib::string &model_input_name) const; + std::optional<vespalib::string> output_name(const vespalib::string &model_output_name) const; + bool operator==(const OnnxModel &rhs) const; + const std::map<vespalib::string,vespalib::string> &inspect_input_features() const { return _input_features; } + const std::map<vespalib::string,vespalib::string> &inspect_output_names() const { return _output_names; } +}; + +} diff --git a/searchlib/src/vespa/searchlib/fef/test/indexenvironment.cpp b/searchlib/src/vespa/searchlib/fef/test/indexenvironment.cpp index 6e2e0b88fbb..d2d336dcdc8 100644 --- a/searchlib/src/vespa/searchlib/fef/test/indexenvironment.cpp +++ b/searchlib/src/vespa/searchlib/fef/test/indexenvironment.cpp @@ -54,20 +54,20 @@ IndexEnvironment::addConstantValue(const vespalib::string &name, (void) insertRes; } -std::optional<vespalib::string> -IndexEnvironment::getOnnxModelFullPath(const vespalib::string &name) const +const OnnxModel * +IndexEnvironment::getOnnxModel(const vespalib::string &name) const { auto pos = _models.find(name); if (pos != _models.end()) { - return pos->second; + return &pos->second; } - return std::nullopt; + return nullptr; } void -IndexEnvironment::addOnnxModel(const vespalib::string &name, const vespalib::string &path) +IndexEnvironment::addOnnxModel(const OnnxModel &model) { - _models[name] = path; + _models.insert_or_assign(model.name(), model); } diff --git a/searchlib/src/vespa/searchlib/fef/test/indexenvironment.h b/searchlib/src/vespa/searchlib/fef/test/indexenvironment.h index 6602d9f8ee9..0d8d0091921 100644 --- a/searchlib/src/vespa/searchlib/fef/test/indexenvironment.h +++ b/searchlib/src/vespa/searchlib/fef/test/indexenvironment.h @@ -5,6 +5,7 @@ #include <vespa/searchlib/attribute/attributemanager.h> #include <vespa/searchlib/fef/iindexenvironment.h> #include <vespa/searchlib/fef/properties.h> +#include <vespa/searchlib/fef/onnx_model.h> #include <vespa/searchlib/fef/fieldinfo.h> #include <vespa/searchlib/fef/tablemanager.h> #include <vespa/eval/eval/value_cache/constant_value.h> @@ -47,7 +48,7 @@ public: }; using ConstantsMap = std::map<vespalib::string, Constant>; - using ModelMap = std::map<vespalib::string, vespalib::string>; + using ModelMap = std::map<vespalib::string, OnnxModel>; IndexEnvironment(); ~IndexEnvironment(); @@ -84,8 +85,8 @@ public: vespalib::eval::ValueType type, std::unique_ptr<vespalib::eval::Value> value); - std::optional<vespalib::string> getOnnxModelFullPath(const vespalib::string &name) const override; - void addOnnxModel(const vespalib::string &name, const vespalib::string &path); + const OnnxModel *getOnnxModel(const vespalib::string &name) const override; + void addOnnxModel(const OnnxModel &model); private: IndexEnvironment(const IndexEnvironment &); // hide diff --git a/streamingvisitors/src/vespa/searchvisitor/indexenvironment.h b/streamingvisitors/src/vespa/searchvisitor/indexenvironment.h index 3bbfb0b23f9..dc7be36c290 100644 --- a/streamingvisitors/src/vespa/searchvisitor/indexenvironment.h +++ b/streamingvisitors/src/vespa/searchvisitor/indexenvironment.h @@ -73,8 +73,8 @@ public: return vespalib::eval::ConstantValue::UP(); } - std::optional<vespalib::string> getOnnxModelFullPath(const vespalib::string &) const override { - return std::nullopt; + const search::fef::OnnxModel *getOnnxModel(const vespalib::string &) const override { + return nullptr; } bool addField(const vespalib::string & name, bool isAttribute); |