diff options
Diffstat (limited to 'searchlib')
5 files changed, 35 insertions, 10 deletions
diff --git a/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp index 7a200a46ab2..826984832f6 100644 --- a/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp +++ b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp @@ -58,9 +58,7 @@ struct OnnxFeatureTest : ::testing::Test { indexEnv.getProperties().add(expr_name, expr); } void add_onnx(const vespalib::string &name, const vespalib::string &file) { - vespalib::string feature_name = onnx_feature(name); - vespalib::string file_name = feature_name + ".fileref"; - indexEnv.getProperties().add(file_name, file); + indexEnv.addOnnxModel(name, file); } void compile(const vespalib::string &seed) { resolver->addSeed(seed); diff --git a/searchlib/src/vespa/searchlib/features/onnx_feature.cpp b/searchlib/src/vespa/searchlib/features/onnx_feature.cpp index 7433021b9b6..b24392ce629 100644 --- a/searchlib/src/vespa/searchlib/features/onnx_feature.cpp +++ b/searchlib/src/vespa/searchlib/features/onnx_feature.cpp @@ -66,15 +66,14 @@ OnnxBlueprint::setup(const IIndexEnvironment &env, auto optimize = (env.getFeatureMotivation() == env.FeatureMotivation::VERIFY_SETUP) ? Onnx::Optimize::DISABLE : Onnx::Optimize::ENABLE; - - // Note: Using the fileref property with the model name as - // fallback to get a file name. This needs to be replaced with an - // actual file reference obtained through config when available. - vespalib::string file_name = env.getProperties().lookup(getName(), "fileref").get(params[0].getValue()); + auto file_name = env.getOnnxModelFullPath(params[0].getValue()); + if (!file_name.has_value()) { + return fail("no model with name '%s' found", params[0].getValue().c_str()); + } try { - _model = std::make_unique<Onnx>(file_name, optimize); + _model = std::make_unique<Onnx>(file_name.value(), optimize); } catch (std::exception &ex) { - return fail("Model setup failed: %s", ex.what()); + return fail("model setup failed: %s", ex.what()); } Onnx::WirePlanner planner; for (size_t i = 0; i < _model->inputs().size(); ++i) { diff --git a/searchlib/src/vespa/searchlib/fef/iindexenvironment.h b/searchlib/src/vespa/searchlib/fef/iindexenvironment.h index bdeead3e852..26e88a98033 100644 --- a/searchlib/src/vespa/searchlib/fef/iindexenvironment.h +++ b/searchlib/src/vespa/searchlib/fef/iindexenvironment.h @@ -3,6 +3,7 @@ #pragma once #include <vespa/vespalib/stllike/string.h> +#include <optional> namespace vespalib::eval { struct ConstantValue; } @@ -120,6 +121,11 @@ public: */ virtual std::unique_ptr<vespalib::eval::ConstantValue> getConstantValue(const vespalib::string &name) const = 0; + /** + * Get the full path of the file containing the given onnx model + **/ + virtual std::optional<vespalib::string> getOnnxModelFullPath(const vespalib::string &name) const = 0; + virtual uint32_t getDistributionKey() const = 0; /** diff --git a/searchlib/src/vespa/searchlib/fef/test/indexenvironment.cpp b/searchlib/src/vespa/searchlib/fef/test/indexenvironment.cpp index e998e4d18bd..6e2e0b88fbb 100644 --- a/searchlib/src/vespa/searchlib/fef/test/indexenvironment.cpp +++ b/searchlib/src/vespa/searchlib/fef/test/indexenvironment.cpp @@ -54,4 +54,21 @@ IndexEnvironment::addConstantValue(const vespalib::string &name, (void) insertRes; } +std::optional<vespalib::string> +IndexEnvironment::getOnnxModelFullPath(const vespalib::string &name) const +{ + auto pos = _models.find(name); + if (pos != _models.end()) { + return pos->second; + } + return std::nullopt; +} + +void +IndexEnvironment::addOnnxModel(const vespalib::string &name, const vespalib::string &path) +{ + _models[name] = path; +} + + } diff --git a/searchlib/src/vespa/searchlib/fef/test/indexenvironment.h b/searchlib/src/vespa/searchlib/fef/test/indexenvironment.h index d84cebc7f52..6602d9f8ee9 100644 --- a/searchlib/src/vespa/searchlib/fef/test/indexenvironment.h +++ b/searchlib/src/vespa/searchlib/fef/test/indexenvironment.h @@ -47,6 +47,7 @@ public: }; using ConstantsMap = std::map<vespalib::string, Constant>; + using ModelMap = std::map<vespalib::string, vespalib::string>; IndexEnvironment(); ~IndexEnvironment(); @@ -83,6 +84,9 @@ public: vespalib::eval::ValueType type, std::unique_ptr<vespalib::eval::Value> value); + std::optional<vespalib::string> getOnnxModelFullPath(const vespalib::string &name) const override; + void addOnnxModel(const vespalib::string &name, const vespalib::string &path); + private: IndexEnvironment(const IndexEnvironment &); // hide IndexEnvironment & operator=(const IndexEnvironment &); // hide @@ -93,6 +97,7 @@ private: AttributeMap _attrMap; TableManager _tableMan; ConstantsMap _constants; + ModelMap _models; }; } |