summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2020-08-31 10:32:50 +0000
committerHåvard Pettersen <havardpe@oath.com>2020-09-01 12:45:13 +0000
commit3440f424ab11d8d8810b6f9785e6a8fad7271fe1 (patch)
treeb769b0acf1485b4b18fd51efe556c8a6b583c7bf /searchlib
parent3b8f7fdff4872bd010286753f6072ec492f14a48 (diff)
handle onnx models config
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp4
-rw-r--r--searchlib/src/vespa/searchlib/features/onnx_feature.cpp13
-rw-r--r--searchlib/src/vespa/searchlib/fef/iindexenvironment.h6
-rw-r--r--searchlib/src/vespa/searchlib/fef/test/indexenvironment.cpp17
-rw-r--r--searchlib/src/vespa/searchlib/fef/test/indexenvironment.h5
5 files changed, 35 insertions, 10 deletions
diff --git a/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp
index 7a200a46ab2..826984832f6 100644
--- a/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp
+++ b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp
@@ -58,9 +58,7 @@ struct OnnxFeatureTest : ::testing::Test {
indexEnv.getProperties().add(expr_name, expr);
}
void add_onnx(const vespalib::string &name, const vespalib::string &file) {
- vespalib::string feature_name = onnx_feature(name);
- vespalib::string file_name = feature_name + ".fileref";
- indexEnv.getProperties().add(file_name, file);
+ indexEnv.addOnnxModel(name, file);
}
void compile(const vespalib::string &seed) {
resolver->addSeed(seed);
diff --git a/searchlib/src/vespa/searchlib/features/onnx_feature.cpp b/searchlib/src/vespa/searchlib/features/onnx_feature.cpp
index 7433021b9b6..b24392ce629 100644
--- a/searchlib/src/vespa/searchlib/features/onnx_feature.cpp
+++ b/searchlib/src/vespa/searchlib/features/onnx_feature.cpp
@@ -66,15 +66,14 @@ OnnxBlueprint::setup(const IIndexEnvironment &env,
auto optimize = (env.getFeatureMotivation() == env.FeatureMotivation::VERIFY_SETUP)
? Onnx::Optimize::DISABLE
: Onnx::Optimize::ENABLE;
-
- // Note: Using the fileref property with the model name as
- // fallback to get a file name. This needs to be replaced with an
- // actual file reference obtained through config when available.
- vespalib::string file_name = env.getProperties().lookup(getName(), "fileref").get(params[0].getValue());
+ auto file_name = env.getOnnxModelFullPath(params[0].getValue());
+ if (!file_name.has_value()) {
+ return fail("no model with name '%s' found", params[0].getValue().c_str());
+ }
try {
- _model = std::make_unique<Onnx>(file_name, optimize);
+ _model = std::make_unique<Onnx>(file_name.value(), optimize);
} catch (std::exception &ex) {
- return fail("Model setup failed: %s", ex.what());
+ return fail("model setup failed: %s", ex.what());
}
Onnx::WirePlanner planner;
for (size_t i = 0; i < _model->inputs().size(); ++i) {
diff --git a/searchlib/src/vespa/searchlib/fef/iindexenvironment.h b/searchlib/src/vespa/searchlib/fef/iindexenvironment.h
index bdeead3e852..26e88a98033 100644
--- a/searchlib/src/vespa/searchlib/fef/iindexenvironment.h
+++ b/searchlib/src/vespa/searchlib/fef/iindexenvironment.h
@@ -3,6 +3,7 @@
#pragma once
#include <vespa/vespalib/stllike/string.h>
+#include <optional>
namespace vespalib::eval { struct ConstantValue; }
@@ -120,6 +121,11 @@ public:
*/
virtual std::unique_ptr<vespalib::eval::ConstantValue> getConstantValue(const vespalib::string &name) const = 0;
+ /**
+ * Get the full path of the file containing the given onnx model
+ **/
+ virtual std::optional<vespalib::string> getOnnxModelFullPath(const vespalib::string &name) const = 0;
+
virtual uint32_t getDistributionKey() const = 0;
/**
diff --git a/searchlib/src/vespa/searchlib/fef/test/indexenvironment.cpp b/searchlib/src/vespa/searchlib/fef/test/indexenvironment.cpp
index e998e4d18bd..6e2e0b88fbb 100644
--- a/searchlib/src/vespa/searchlib/fef/test/indexenvironment.cpp
+++ b/searchlib/src/vespa/searchlib/fef/test/indexenvironment.cpp
@@ -54,4 +54,21 @@ IndexEnvironment::addConstantValue(const vespalib::string &name,
(void) insertRes;
}
+std::optional<vespalib::string>
+IndexEnvironment::getOnnxModelFullPath(const vespalib::string &name) const
+{
+ auto pos = _models.find(name);
+ if (pos != _models.end()) {
+ return pos->second;
+ }
+ return std::nullopt;
+}
+
+void
+IndexEnvironment::addOnnxModel(const vespalib::string &name, const vespalib::string &path)
+{
+ _models[name] = path;
+}
+
+
}
diff --git a/searchlib/src/vespa/searchlib/fef/test/indexenvironment.h b/searchlib/src/vespa/searchlib/fef/test/indexenvironment.h
index d84cebc7f52..6602d9f8ee9 100644
--- a/searchlib/src/vespa/searchlib/fef/test/indexenvironment.h
+++ b/searchlib/src/vespa/searchlib/fef/test/indexenvironment.h
@@ -47,6 +47,7 @@ public:
};
using ConstantsMap = std::map<vespalib::string, Constant>;
+ using ModelMap = std::map<vespalib::string, vespalib::string>;
IndexEnvironment();
~IndexEnvironment();
@@ -83,6 +84,9 @@ public:
vespalib::eval::ValueType type,
std::unique_ptr<vespalib::eval::Value> value);
+ std::optional<vespalib::string> getOnnxModelFullPath(const vespalib::string &name) const override;
+ void addOnnxModel(const vespalib::string &name, const vespalib::string &path);
+
private:
IndexEnvironment(const IndexEnvironment &); // hide
IndexEnvironment & operator=(const IndexEnvironment &); // hide
@@ -93,6 +97,7 @@ private:
AttributeMap _attrMap;
TableManager _tableMan;
ConstantsMap _constants;
+ ModelMap _models;
};
}