aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
author Håvard Pettersen <havardpe@oath.com> 2020-09-22 08:28:02 +0000
committer Håvard Pettersen <havardpe@oath.com> 2020-09-22 10:15:54 +0000
commit 9a66f21d375e4fa07a96069644615581e55129d5 (patch)
tree 06891b9c7f91fd0cf673cd1c540cb0dcafd5e950 /searchlib
parent 804e8057c2eca61ec9bc8985430613e0731922a2 (diff)
handle onnx model config for inputs and outputs
Diffstat (limited to 'searchlib')
-rw-r--r-- searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp 26
-rw-r--r-- searchlib/src/vespa/searchlib/features/onnx_feature.cpp 29
-rw-r--r-- searchlib/src/vespa/searchlib/fef/CMakeLists.txt 3
-rw-r--r-- searchlib/src/vespa/searchlib/fef/iindexenvironment.h 6
-rw-r--r-- searchlib/src/vespa/searchlib/fef/onnx_model.cpp 55
-rw-r--r-- searchlib/src/vespa/searchlib/fef/onnx_model.h 39
-rw-r--r-- searchlib/src/vespa/searchlib/fef/test/indexenvironment.cpp 12
-rw-r--r-- searchlib/src/vespa/searchlib/fef/test/indexenvironment.h 7
8 files changed, 149 insertions, 28 deletions
diff --git a/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp
index b49d9c365de..6a1e4ef9fa1 100644
--- a/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp
+++ b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp
@@ -58,8 +58,8 @@ struct OnnxFeatureTest : ::testing::Test {
vespalib::string expr_name = feature_name + ".rankingScript";
indexEnv.getProperties().add(expr_name, expr);
}
- void add_onnx(const vespalib::string &name, const vespalib::string &file) {
- indexEnv.addOnnxModel(name, file);
+ void add_onnx(const OnnxModel &model) {
+ indexEnv.addOnnxModel(model);
}
void compile(const vespalib::string &seed) {
resolver->addSeed(seed);
@@ -89,7 +89,7 @@ TEST_F(OnnxFeatureTest, simple_onnx_model_can_be_calculated) {
add_expr("query_tensor", "tensor<float>(a[1],b[4]):[[docid,2,3,4]]");
add_expr("attribute_tensor", "tensor<float>(a[4],b[1]):[[5],[6],[7],[8]]");
add_expr("bias_tensor", "tensor<float>(a[1],b[1]):[[9]]");
- add_onnx("simple", simple_model);
+ add_onnx(OnnxModel("simple", simple_model));
compile(onnx_feature("simple"));
EXPECT_EQ(get(1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0));
EXPECT_EQ(get("onnxModel(simple).output", 1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0));
@@ -101,7 +101,7 @@ TEST_F(OnnxFeatureTest, dynamic_onnx_model_can_be_calculated) {
add_expr("query_tensor", "tensor<float>(a[1],b[4]):[[docid,2,3,4]]");
add_expr("attribute_tensor", "tensor<float>(a[4],b[1]):[[5],[6],[7],[8]]");
add_expr("bias_tensor", "tensor<float>(a[1],b[2]):[[4,5]]");
- add_onnx("dynamic", dynamic_model);
+ add_onnx(OnnxModel("dynamic", dynamic_model));
compile(onnx_feature("dynamic"));
EXPECT_EQ(get(1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0));
EXPECT_EQ(get("onnxModel(dynamic).output", 1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0));
@@ -112,7 +112,7 @@ TEST_F(OnnxFeatureTest, dynamic_onnx_model_can_be_calculated) {
TEST_F(OnnxFeatureTest, strange_input_and_output_names_are_normalized) {
add_expr("input_0", "tensor<float>(a[2]):[10,20]");
add_expr("input_1", "tensor<float>(a[2]):[5,10]");
- add_onnx("strange_names", strange_names_model);
+ add_onnx(OnnxModel("strange_names", strange_names_model));
compile(onnx_feature("strange_names"));
auto expect_add = TensorSpec("tensor<float>(d0[2])").add({{"d0",0}},15).add({{"d0",1}},30);
auto expect_sub = TensorSpec("tensor<float>(d0[2])").add({{"d0",0}},5).add({{"d0",1}},10);
@@ -121,4 +121,20 @@ TEST_F(OnnxFeatureTest, strange_input_and_output_names_are_normalized) {
EXPECT_EQ(get("onnxModel(strange_names)._baz_0", 1), expect_sub);
}
+TEST_F(OnnxFeatureTest, input_features_and_output_names_can_be_specified) { // explicit input/output mappings on the model config are honored
+ add_expr("my_first_input", "tensor<float>(a[2]):[10,20]");
+ add_expr("my_second_input", "tensor<float>(a[2]):[5,10]");
+ add_onnx(OnnxModel("custom_names", strange_names_model) // same model file as the strange_names test, registered under a new name
+ .input_feature("input:0", "rankingExpression(my_first_input)") // raw model input name -> feature to read it from
+ .input_feature("input/1", "rankingExpression(my_second_input)")
+ .output_name("foo/bar", "my_first_output") // raw model output name -> feature output name
+ .output_name("-baz:0", "my_second_output"));
+ compile(onnx_feature("custom_names"));
+ auto expect_add = TensorSpec("tensor<float>(d0[2])").add({{"d0",0}},15).add({{"d0",1}},30); // 10+5 and 20+10
+ auto expect_sub = TensorSpec("tensor<float>(d0[2])").add({{"d0",0}},5).add({{"d0",1}},10); // 10-5 and 20-10
+ EXPECT_EQ(get(1), expect_add); // NOTE(review): presumably the feature's default (first) output for docid 1 - confirm against get()
+ EXPECT_EQ(get("onnxModel(custom_names).my_first_output", 1), expect_add); // outputs are addressable by their configured names
+ EXPECT_EQ(get("onnxModel(custom_names).my_second_output", 1), expect_sub);
+}
+
GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/searchlib/src/vespa/searchlib/features/onnx_feature.cpp b/searchlib/src/vespa/searchlib/features/onnx_feature.cpp
index 698d2309e5a..fca8988ba36 100644
--- a/searchlib/src/vespa/searchlib/features/onnx_feature.cpp
+++ b/searchlib/src/vespa/searchlib/features/onnx_feature.cpp
@@ -2,6 +2,7 @@
#include "onnx_feature.h"
#include <vespa/searchlib/fef/properties.h>
+#include <vespa/searchlib/fef/onnx_model.h>
#include <vespa/searchlib/fef/featureexecutor.h>
#include <vespa/eval/tensor/dense/dense_tensor_view.h>
#include <vespa/eval/tensor/dense/mutable_dense_tensor_view.h>
@@ -85,37 +86,45 @@ OnnxBlueprint::setup(const IIndexEnvironment &env,
auto optimize = (env.getFeatureMotivation() == env.FeatureMotivation::VERIFY_SETUP)
? Onnx::Optimize::DISABLE
: Onnx::Optimize::ENABLE;
- auto file_name = env.getOnnxModelFullPath(params[0].getValue());
- if (!file_name.has_value()) {
+ auto model_cfg = env.getOnnxModel(params[0].getValue());
+ if (!model_cfg) {
return fail("no model with name '%s' found", params[0].getValue().c_str());
}
try {
- _model = std::make_unique<Onnx>(file_name.value(), optimize);
+ _model = std::make_unique<Onnx>(model_cfg->file_path(), optimize);
} catch (std::exception &ex) {
return fail("model setup failed: %s", ex.what());
}
Onnx::WirePlanner planner;
for (size_t i = 0; i < _model->inputs().size(); ++i) {
const auto &model_input = _model->inputs()[i];
- vespalib::string input_name = normalize_name(model_input.name, "input");
- if (auto maybe_input = defineInput(fmt("rankingExpression(\"%s\")", input_name.c_str()), AcceptInput::OBJECT)) {
+ auto input_feature = model_cfg->input_feature(model_input.name);
+ if (!input_feature.has_value()) {
+ input_feature = fmt("rankingExpression(\"%s\")", normalize_name(model_input.name, "input").c_str());
+ }
+ if (auto maybe_input = defineInput(input_feature.value(), AcceptInput::OBJECT)) {
const FeatureType &feature_input = maybe_input.value();
assert(feature_input.is_object());
if (!planner.bind_input_type(feature_input.type(), model_input)) {
- return fail("incompatible type for input '%s': %s -> %s", input_name.c_str(),
+ return fail("incompatible type for input (%s -> %s): %s -> %s",
+ input_feature.value().c_str(), model_input.name.c_str(),
feature_input.type().to_spec().c_str(), model_input.type_as_string().c_str());
}
}
}
for (size_t i = 0; i < _model->outputs().size(); ++i) {
const auto &model_output = _model->outputs()[i];
- vespalib::string output_name = normalize_name(model_output.name, "output");
+ auto output_name = model_cfg->output_name(model_output.name);
+ if (!output_name.has_value()) {
+ output_name = normalize_name(model_output.name, "output");
+ }
ValueType output_type = planner.make_output_type(model_output);
if (output_type.is_error()) {
- return fail("unable to make compatible type for output '%s': %s -> error",
- output_name.c_str(), model_output.type_as_string().c_str());
+ return fail("unable to make compatible type for output (%s -> %s): %s -> error",
+ model_output.name.c_str(), output_name.value().c_str(),
+ model_output.type_as_string().c_str());
}
- describeOutput(output_name, "output from onnx model", FeatureType::object(output_type));
+ describeOutput(output_name.value(), "output from onnx model", FeatureType::object(output_type));
}
_wire_info = planner.get_wire_info(*_model);
return true;
diff --git a/searchlib/src/vespa/searchlib/fef/CMakeLists.txt b/searchlib/src/vespa/searchlib/fef/CMakeLists.txt
index 178de1b8b87..d6f8764cd63 100644
--- a/searchlib/src/vespa/searchlib/fef/CMakeLists.txt
+++ b/searchlib/src/vespa/searchlib/fef/CMakeLists.txt
@@ -4,12 +4,12 @@ vespa_add_library(searchlib_fef OBJECT
blueprint.cpp
blueprintfactory.cpp
blueprintresolver.cpp
+ feature_resolver.cpp
feature_type.cpp
featureexecutor.cpp
featurenamebuilder.cpp
featurenameparser.cpp
featureoverrider.cpp
- feature_resolver.cpp
fef.cpp
fieldinfo.cpp
fieldpositionsiterator.cpp
@@ -19,6 +19,7 @@ vespa_add_library(searchlib_fef OBJECT
matchdata.cpp
matchdatalayout.cpp
objectstore.cpp
+ onnx_model.cpp
parameter.cpp
parameterdescriptions.cpp
parametervalidator.cpp
diff --git a/searchlib/src/vespa/searchlib/fef/iindexenvironment.h b/searchlib/src/vespa/searchlib/fef/iindexenvironment.h
index 26e88a98033..384b81643cc 100644
--- a/searchlib/src/vespa/searchlib/fef/iindexenvironment.h
+++ b/searchlib/src/vespa/searchlib/fef/iindexenvironment.h
@@ -3,7 +3,6 @@
#pragma once
#include <vespa/vespalib/stllike/string.h>
-#include <optional>
namespace vespalib::eval { struct ConstantValue; }
@@ -12,6 +11,7 @@ namespace search::fef {
class Properties;
class FieldInfo;
class ITableManager;
+class OnnxModel;
/**
* Abstract view of index related information available to the
@@ -122,9 +122,9 @@ public:
virtual std::unique_ptr<vespalib::eval::ConstantValue> getConstantValue(const vespalib::string &name) const = 0;
/**
- * Get the full path of the file containing the given onnx model
+ * Get configuration for the given onnx model.
**/
- virtual std::optional<vespalib::string> getOnnxModelFullPath(const vespalib::string &name) const = 0;
+ virtual const OnnxModel *getOnnxModel(const vespalib::string &name) const = 0;
virtual uint32_t getDistributionKey() const = 0;
diff --git a/searchlib/src/vespa/searchlib/fef/onnx_model.cpp b/searchlib/src/vespa/searchlib/fef/onnx_model.cpp
new file mode 100644
index 00000000000..ba5adaae857
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/fef/onnx_model.cpp
@@ -0,0 +1,55 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "onnx_model.h"
+#include <tuple>
+
+namespace search::fef {
+
+OnnxModel::OnnxModel(const vespalib::string &name_in, // copied into _name
+ const vespalib::string &file_path_in) // copied into _file_path
+ : _name(name_in),
+ _file_path(file_path_in),
+ _input_features(), // starts empty; entries are added by input_feature(name, feature)
+ _output_names() // starts empty; entries are added by output_name(name, output)
+{
+}
+
+OnnxModel & // returns *this so that configuration calls can be chained
+OnnxModel::input_feature(const vespalib::string &model_input_name, const vespalib::string &input_feature) {
+ _input_features[model_input_name] = input_feature; // inserts the mapping, or overwrites an existing one for this model input
+ return *this;
+}
+
+OnnxModel & // returns *this so that configuration calls can be chained
+OnnxModel::output_name(const vespalib::string &model_output_name, const vespalib::string &output_name) {
+ _output_names[model_output_name] = output_name; // inserts the mapping, or overwrites an existing one for this model output
+ return *this;
+}
+
+std::optional<vespalib::string> // the feature configured for this model input, or std::nullopt when none was set
+OnnxModel::input_feature(const vespalib::string &model_input_name) const {
+ auto pos = _input_features.find(model_input_name); // exact-match lookup on the model input name
+ if (pos != _input_features.end()) {
+ return pos->second; // returned by value (a copy of the stored string)
+ }
+ return std::nullopt; // no mapping configured; the caller picks any fallback
+}
+
+std::optional<vespalib::string> // the name configured for this model output, or std::nullopt when none was set
+OnnxModel::output_name(const vespalib::string &model_output_name) const {
+ auto pos = _output_names.find(model_output_name); // exact-match lookup on the model output name
+ if (pos != _output_names.end()) {
+ return pos->second; // returned by value (a copy of the stored string)
+ }
+ return std::nullopt; // no mapping configured; the caller picks any fallback
+}
+
+bool // true when name, file path and both mapping tables are all equal
+OnnxModel::operator==(const OnnxModel &rhs) const {
+ return (std::tie(_name, _file_path, _input_features, _output_names) == // tuples of references: member-wise comparison without copying
+ std::tie(rhs._name, rhs._file_path, rhs._input_features, rhs._output_names));
+}
+
+OnnxModel::~OnnxModel() = default; // declared in the header, defaulted here in the .cpp file
+
+}
diff --git a/searchlib/src/vespa/searchlib/fef/onnx_model.h b/searchlib/src/vespa/searchlib/fef/onnx_model.h
new file mode 100644
index 00000000000..2195a50600d
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/fef/onnx_model.h
@@ -0,0 +1,39 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include <vespa/vespalib/stllike/string.h>
+#include <optional>
+#include <map>
+
+namespace search::fef {
+
+/**
+ * Class containing configuration for a single onnx model setup. This
+ * class is used both by the IIndexEnvironment api as well as the
+ * OnnxModels config adapter in the search core (matching component).
+ **/
+class OnnxModel {
+private:
+ vespalib::string _name; // model name, as given to the constructor
+ vespalib::string _file_path; // model file path, as given to the constructor
+ std::map<vespalib::string,vespalib::string> _input_features; // model input name -> input feature
+ std::map<vespalib::string,vespalib::string> _output_names; // model output name -> output name
+
+public:
+ OnnxModel(const vespalib::string &name_in, // creates a model config with no input/output mappings
+ const vespalib::string &file_path_in);
+ ~OnnxModel();
+
+ const vespalib::string &name() const { return _name; }
+ const vespalib::string &file_path() const { return _file_path; }
+ OnnxModel &input_feature(const vespalib::string &model_input_name, const vespalib::string &input_feature); // setter: records a mapping for a model input
+ OnnxModel &output_name(const vespalib::string &model_output_name, const vespalib::string &output_name); // setter: records a mapping for a model output
+ std::optional<vespalib::string> input_feature(const vespalib::string &model_input_name) const; // getter: looks up the mapping for a model input
+ std::optional<vespalib::string> output_name(const vespalib::string &model_output_name) const; // getter: looks up the mapping for a model output
+ bool operator==(const OnnxModel &rhs) const;
+ const std::map<vespalib::string,vespalib::string> &inspect_input_features() const { return _input_features; } // read-only view of all input mappings
+ const std::map<vespalib::string,vespalib::string> &inspect_output_names() const { return _output_names; } // read-only view of all output mappings
+};
+
+}
diff --git a/searchlib/src/vespa/searchlib/fef/test/indexenvironment.cpp b/searchlib/src/vespa/searchlib/fef/test/indexenvironment.cpp
index 6e2e0b88fbb..d2d336dcdc8 100644
--- a/searchlib/src/vespa/searchlib/fef/test/indexenvironment.cpp
+++ b/searchlib/src/vespa/searchlib/fef/test/indexenvironment.cpp
@@ -54,20 +54,20 @@ IndexEnvironment::addConstantValue(const vespalib::string &name,
(void) insertRes;
}
-std::optional<vespalib::string>
-IndexEnvironment::getOnnxModelFullPath(const vespalib::string &name) const
+const OnnxModel *
+IndexEnvironment::getOnnxModel(const vespalib::string &name) const
{
auto pos = _models.find(name);
if (pos != _models.end()) {
- return pos->second;
+ return &pos->second;
}
- return std::nullopt;
+ return nullptr;
}
void
-IndexEnvironment::addOnnxModel(const vespalib::string &name, const vespalib::string &path)
+IndexEnvironment::addOnnxModel(const OnnxModel &model)
{
- _models[name] = path;
+ _models.insert_or_assign(model.name(), model);
}
diff --git a/searchlib/src/vespa/searchlib/fef/test/indexenvironment.h b/searchlib/src/vespa/searchlib/fef/test/indexenvironment.h
index 6602d9f8ee9..0d8d0091921 100644
--- a/searchlib/src/vespa/searchlib/fef/test/indexenvironment.h
+++ b/searchlib/src/vespa/searchlib/fef/test/indexenvironment.h
@@ -5,6 +5,7 @@
#include <vespa/searchlib/attribute/attributemanager.h>
#include <vespa/searchlib/fef/iindexenvironment.h>
#include <vespa/searchlib/fef/properties.h>
+#include <vespa/searchlib/fef/onnx_model.h>
#include <vespa/searchlib/fef/fieldinfo.h>
#include <vespa/searchlib/fef/tablemanager.h>
#include <vespa/eval/eval/value_cache/constant_value.h>
@@ -47,7 +48,7 @@ public:
};
using ConstantsMap = std::map<vespalib::string, Constant>;
- using ModelMap = std::map<vespalib::string, vespalib::string>;
+ using ModelMap = std::map<vespalib::string, OnnxModel>;
IndexEnvironment();
~IndexEnvironment();
@@ -84,8 +85,8 @@ public:
vespalib::eval::ValueType type,
std::unique_ptr<vespalib::eval::Value> value);
- std::optional<vespalib::string> getOnnxModelFullPath(const vespalib::string &name) const override;
- void addOnnxModel(const vespalib::string &name, const vespalib::string &path);
+ const OnnxModel *getOnnxModel(const vespalib::string &name) const override;
+ void addOnnxModel(const OnnxModel &model);
private:
IndexEnvironment(const IndexEnvironment &); // hide