diff options
author | Håvard Pettersen <havardpe@oath.com> | 2020-09-22 08:28:02 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2020-09-22 10:15:54 +0000 |
commit | 9a66f21d375e4fa07a96069644615581e55129d5 (patch) | |
tree | 06891b9c7f91fd0cf673cd1c540cb0dcafd5e950 /searchlib | |
parent | 804e8057c2eca61ec9bc8985430613e0731922a2 (diff) |
handle onnx model config for inputs and outputs
Diffstat (limited to 'searchlib')
8 files changed, 149 insertions, 28 deletions
diff --git a/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp index b49d9c365de..6a1e4ef9fa1 100644 --- a/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp +++ b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp @@ -58,8 +58,8 @@ struct OnnxFeatureTest : ::testing::Test { vespalib::string expr_name = feature_name + ".rankingScript"; indexEnv.getProperties().add(expr_name, expr); } - void add_onnx(const vespalib::string &name, const vespalib::string &file) { - indexEnv.addOnnxModel(name, file); + void add_onnx(const OnnxModel &model) { + indexEnv.addOnnxModel(model); } void compile(const vespalib::string &seed) { resolver->addSeed(seed); @@ -89,7 +89,7 @@ TEST_F(OnnxFeatureTest, simple_onnx_model_can_be_calculated) { add_expr("query_tensor", "tensor<float>(a[1],b[4]):[[docid,2,3,4]]"); add_expr("attribute_tensor", "tensor<float>(a[4],b[1]):[[5],[6],[7],[8]]"); add_expr("bias_tensor", "tensor<float>(a[1],b[1]):[[9]]"); - add_onnx("simple", simple_model); + add_onnx(OnnxModel("simple", simple_model)); compile(onnx_feature("simple")); EXPECT_EQ(get(1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0)); EXPECT_EQ(get("onnxModel(simple).output", 1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0)); @@ -101,7 +101,7 @@ TEST_F(OnnxFeatureTest, dynamic_onnx_model_can_be_calculated) { add_expr("query_tensor", "tensor<float>(a[1],b[4]):[[docid,2,3,4]]"); add_expr("attribute_tensor", "tensor<float>(a[4],b[1]):[[5],[6],[7],[8]]"); add_expr("bias_tensor", "tensor<float>(a[1],b[2]):[[4,5]]"); - add_onnx("dynamic", dynamic_model); + add_onnx(OnnxModel("dynamic", dynamic_model)); compile(onnx_feature("dynamic")); EXPECT_EQ(get(1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0)); EXPECT_EQ(get("onnxModel(dynamic).output", 1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0)); @@ -112,7 +112,7 @@ TEST_F(OnnxFeatureTest, dynamic_onnx_model_can_be_calculated) { TEST_F(OnnxFeatureTest, strange_input_and_output_names_are_normalized) { add_expr("input_0", "tensor<float>(a[2]):[10,20]"); add_expr("input_1", "tensor<float>(a[2]):[5,10]"); - add_onnx("strange_names", strange_names_model); + add_onnx(OnnxModel("strange_names", strange_names_model)); compile(onnx_feature("strange_names")); auto expect_add = TensorSpec("tensor<float>(d0[2])").add({{"d0",0}},15).add({{"d0",1}},30); auto expect_sub = TensorSpec("tensor<float>(d0[2])").add({{"d0",0}},5).add({{"d0",1}},10); @@ -121,4 +121,20 @@ TEST_F(OnnxFeatureTest, strange_input_and_output_names_are_normalized) { EXPECT_EQ(get("onnxModel(strange_names)._baz_0", 1), expect_sub); } +TEST_F(OnnxFeatureTest, input_features_and_output_names_can_be_specified) { + add_expr("my_first_input", "tensor<float>(a[2]):[10,20]"); + add_expr("my_second_input", "tensor<float>(a[2]):[5,10]"); + add_onnx(OnnxModel("custom_names", strange_names_model) + .input_feature("input:0", "rankingExpression(my_first_input)") + .input_feature("input/1", "rankingExpression(my_second_input)") + .output_name("foo/bar", "my_first_output") + .output_name("-baz:0", "my_second_output")); + compile(onnx_feature("custom_names")); + auto expect_add = TensorSpec("tensor<float>(d0[2])").add({{"d0",0}},15).add({{"d0",1}},30); + auto expect_sub = TensorSpec("tensor<float>(d0[2])").add({{"d0",0}},5).add({{"d0",1}},10); + EXPECT_EQ(get(1), expect_add); + EXPECT_EQ(get("onnxModel(custom_names).my_first_output", 1), expect_add); + EXPECT_EQ(get("onnxModel(custom_names).my_second_output", 1), expect_sub); +} + GTEST_MAIN_RUN_ALL_TESTS() diff --git a/searchlib/src/vespa/searchlib/features/onnx_feature.cpp b/searchlib/src/vespa/searchlib/features/onnx_feature.cpp index 698d2309e5a..fca8988ba36 100644 --- a/searchlib/src/vespa/searchlib/features/onnx_feature.cpp +++ b/searchlib/src/vespa/searchlib/features/onnx_feature.cpp @@ -2,6 +2,7 @@ #include "onnx_feature.h" #include <vespa/searchlib/fef/properties.h> +#include <vespa/searchlib/fef/onnx_model.h> #include <vespa/searchlib/fef/featureexecutor.h> #include <vespa/eval/tensor/dense/dense_tensor_view.h> #include <vespa/eval/tensor/dense/mutable_dense_tensor_view.h> @@ -85,37 +86,45 @@ OnnxBlueprint::setup(const IIndexEnvironment &env, auto optimize = (env.getFeatureMotivation() == env.FeatureMotivation::VERIFY_SETUP) ? Onnx::Optimize::DISABLE : Onnx::Optimize::ENABLE; - auto file_name = env.getOnnxModelFullPath(params[0].getValue()); - if (!file_name.has_value()) { + auto model_cfg = env.getOnnxModel(params[0].getValue()); + if (!model_cfg) { return fail("no model with name '%s' found", params[0].getValue().c_str()); } try { - _model = std::make_unique<Onnx>(file_name.value(), optimize); + _model = std::make_unique<Onnx>(model_cfg->file_path(), optimize); } catch (std::exception &ex) { return fail("model setup failed: %s", ex.what()); } Onnx::WirePlanner planner; for (size_t i = 0; i < _model->inputs().size(); ++i) { const auto &model_input = _model->inputs()[i]; - vespalib::string input_name = normalize_name(model_input.name, "input"); - if (auto maybe_input = defineInput(fmt("rankingExpression(\"%s\")", input_name.c_str()), AcceptInput::OBJECT)) { + auto input_feature = model_cfg->input_feature(model_input.name); + if (!input_feature.has_value()) { + input_feature = fmt("rankingExpression(\"%s\")", normalize_name(model_input.name, "input").c_str()); + } + if (auto maybe_input = defineInput(input_feature.value(), AcceptInput::OBJECT)) { const FeatureType &feature_input = maybe_input.value(); assert(feature_input.is_object()); if (!planner.bind_input_type(feature_input.type(), model_input)) { - return fail("incompatible type for input '%s': %s -> %s", input_name.c_str(), + return fail("incompatible type for input (%s -> %s): %s -> %s", + input_feature.value().c_str(), model_input.name.c_str(), feature_input.type().to_spec().c_str(), model_input.type_as_string().c_str()); } } } for (size_t i = 0; i < _model->outputs().size(); ++i) { const auto &model_output = _model->outputs()[i]; - vespalib::string output_name = normalize_name(model_output.name, "output"); + auto output_name = model_cfg->output_name(model_output.name); + if (!output_name.has_value()) { + output_name = normalize_name(model_output.name, "output"); + } ValueType output_type = planner.make_output_type(model_output); if (output_type.is_error()) { - return fail("unable to make compatible type for output '%s': %s -> error", - output_name.c_str(), model_output.type_as_string().c_str()); + return fail("unable to make compatible type for output (%s -> %s): %s -> error", + model_output.name.c_str(), output_name.value().c_str(), + model_output.type_as_string().c_str()); } - describeOutput(output_name, "output from onnx model", FeatureType::object(output_type)); + describeOutput(output_name.value(), "output from onnx model", FeatureType::object(output_type)); } _wire_info = planner.get_wire_info(*_model); return true; diff --git a/searchlib/src/vespa/searchlib/fef/CMakeLists.txt b/searchlib/src/vespa/searchlib/fef/CMakeLists.txt index 178de1b8b87..d6f8764cd63 100644 --- a/searchlib/src/vespa/searchlib/fef/CMakeLists.txt +++ b/searchlib/src/vespa/searchlib/fef/CMakeLists.txt @@ -4,12 +4,12 @@ vespa_add_library(searchlib_fef OBJECT blueprint.cpp blueprintfactory.cpp blueprintresolver.cpp + feature_resolver.cpp feature_type.cpp featureexecutor.cpp featurenamebuilder.cpp featurenameparser.cpp featureoverrider.cpp - feature_resolver.cpp fef.cpp fieldinfo.cpp fieldpositionsiterator.cpp @@ -19,6 +19,7 @@ vespa_add_library(searchlib_fef OBJECT matchdata.cpp matchdatalayout.cpp objectstore.cpp + onnx_model.cpp parameter.cpp parameterdescriptions.cpp parametervalidator.cpp diff --git a/searchlib/src/vespa/searchlib/fef/iindexenvironment.h b/searchlib/src/vespa/searchlib/fef/iindexenvironment.h index 26e88a98033..384b81643cc 100644 --- a/searchlib/src/vespa/searchlib/fef/iindexenvironment.h +++ b/searchlib/src/vespa/searchlib/fef/iindexenvironment.h @@ -3,7 +3,6 @@ #pragma once #include <vespa/vespalib/stllike/string.h> -#include <optional> namespace vespalib::eval { struct ConstantValue; } @@ -12,6 +11,7 @@ namespace search::fef { class Properties; class FieldInfo; class ITableManager; +class OnnxModel; /** * Abstract view of index related information available to the @@ -122,9 +122,9 @@ public: virtual std::unique_ptr<vespalib::eval::ConstantValue> getConstantValue(const vespalib::string &name) const = 0; /** - * Get the full path of the file containing the given onnx model + * Get configuration for the given onnx model. **/ - virtual std::optional<vespalib::string> getOnnxModelFullPath(const vespalib::string &name) const = 0; + virtual const OnnxModel *getOnnxModel(const vespalib::string &name) const = 0; virtual uint32_t getDistributionKey() const = 0; diff --git a/searchlib/src/vespa/searchlib/fef/onnx_model.cpp b/searchlib/src/vespa/searchlib/fef/onnx_model.cpp new file mode 100644 index 00000000000..ba5adaae857 --- /dev/null +++ b/searchlib/src/vespa/searchlib/fef/onnx_model.cpp @@ -0,0 +1,55 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "onnx_model.h" +#include <tuple> + +namespace search::fef { + +OnnxModel::OnnxModel(const vespalib::string &name_in, + const vespalib::string &file_path_in) + : _name(name_in), + _file_path(file_path_in), + _input_features(), + _output_names() +{ +} + +OnnxModel & +OnnxModel::input_feature(const vespalib::string &model_input_name, const vespalib::string &input_feature) { + _input_features[model_input_name] = input_feature; + return *this; +} + +OnnxModel & +OnnxModel::output_name(const vespalib::string &model_output_name, const vespalib::string &output_name) { + _output_names[model_output_name] = output_name; + return *this; +} + +std::optional<vespalib::string> +OnnxModel::input_feature(const vespalib::string &model_input_name) const { + auto pos = _input_features.find(model_input_name); + if (pos != _input_features.end()) { + return pos->second; + } + return std::nullopt; +} + +std::optional<vespalib::string> +OnnxModel::output_name(const vespalib::string &model_output_name) const { + auto pos = _output_names.find(model_output_name); + if (pos != _output_names.end()) { + return pos->second; + } + return std::nullopt; +} + +bool +OnnxModel::operator==(const OnnxModel &rhs) const { + return (std::tie(_name, _file_path, _input_features, _output_names) == + std::tie(rhs._name, rhs._file_path, rhs._input_features, rhs._output_names)); +} + +OnnxModel::~OnnxModel() = default; + +} diff --git a/searchlib/src/vespa/searchlib/fef/onnx_model.h b/searchlib/src/vespa/searchlib/fef/onnx_model.h new file mode 100644 index 00000000000..2195a50600d --- /dev/null +++ b/searchlib/src/vespa/searchlib/fef/onnx_model.h @@ -0,0 +1,39 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include <vespa/vespalib/stllike/string.h> +#include <optional> +#include <map> + +namespace search::fef { + +/** + * Class containing configuration for a single onnx model setup. This + * class is used both by the IIndexEnvironment api as well as the + * OnnxModels config adapter in the search core (matching component). + **/ +class OnnxModel { +private: + vespalib::string _name; + vespalib::string _file_path; + std::map<vespalib::string,vespalib::string> _input_features; + std::map<vespalib::string,vespalib::string> _output_names; + +public: + OnnxModel(const vespalib::string &name_in, + const vespalib::string &file_path_in); + ~OnnxModel(); + + const vespalib::string &name() const { return _name; } + const vespalib::string &file_path() const { return _file_path; } + OnnxModel &input_feature(const vespalib::string &model_input_name, const vespalib::string &input_feature); + OnnxModel &output_name(const vespalib::string &model_output_name, const vespalib::string &output_name); + std::optional<vespalib::string> input_feature(const vespalib::string &model_input_name) const; + std::optional<vespalib::string> output_name(const vespalib::string &model_output_name) const; + bool operator==(const OnnxModel &rhs) const; + const std::map<vespalib::string,vespalib::string> &inspect_input_features() const { return _input_features; } + const std::map<vespalib::string,vespalib::string> &inspect_output_names() const { return _output_names; } +}; + +} diff --git a/searchlib/src/vespa/searchlib/fef/test/indexenvironment.cpp b/searchlib/src/vespa/searchlib/fef/test/indexenvironment.cpp index 6e2e0b88fbb..d2d336dcdc8 100644 --- a/searchlib/src/vespa/searchlib/fef/test/indexenvironment.cpp +++ b/searchlib/src/vespa/searchlib/fef/test/indexenvironment.cpp @@ -54,20 +54,20 @@ IndexEnvironment::addConstantValue(const vespalib::string &name, (void) insertRes; } -std::optional<vespalib::string> -IndexEnvironment::getOnnxModelFullPath(const vespalib::string &name) const +const OnnxModel * +IndexEnvironment::getOnnxModel(const vespalib::string &name) const { auto pos = _models.find(name); if (pos != _models.end()) { - return pos->second; + return &pos->second; } - return std::nullopt; + return nullptr; } void -IndexEnvironment::addOnnxModel(const vespalib::string &name, const vespalib::string &path) +IndexEnvironment::addOnnxModel(const OnnxModel &model) { - _models[name] = path; + _models.insert_or_assign(model.name(), model); } diff --git a/searchlib/src/vespa/searchlib/fef/test/indexenvironment.h b/searchlib/src/vespa/searchlib/fef/test/indexenvironment.h index 6602d9f8ee9..0d8d0091921 100644 --- a/searchlib/src/vespa/searchlib/fef/test/indexenvironment.h +++ b/searchlib/src/vespa/searchlib/fef/test/indexenvironment.h @@ -5,6 +5,7 @@ #include <vespa/searchlib/attribute/attributemanager.h> #include <vespa/searchlib/fef/iindexenvironment.h> #include <vespa/searchlib/fef/properties.h> +#include <vespa/searchlib/fef/onnx_model.h> #include <vespa/searchlib/fef/fieldinfo.h> #include <vespa/searchlib/fef/tablemanager.h> #include <vespa/eval/eval/value_cache/constant_value.h> @@ -47,7 +48,7 @@ public: }; using ConstantsMap = std::map<vespalib::string, Constant>; - using ModelMap = std::map<vespalib::string, vespalib::string>; + using ModelMap = std::map<vespalib::string, OnnxModel>; IndexEnvironment(); ~IndexEnvironment(); @@ -84,8 +85,8 @@ public: vespalib::eval::ValueType type, std::unique_ptr<vespalib::eval::Value> value); - std::optional<vespalib::string> getOnnxModelFullPath(const vespalib::string &name) const override; - void addOnnxModel(const vespalib::string &name, const vespalib::string &path); + const OnnxModel *getOnnxModel(const vespalib::string &name) const override; + void addOnnxModel(const OnnxModel &model); private: IndexEnvironment(const IndexEnvironment &); // hide |