diff options
34 files changed, 312 insertions, 32 deletions
diff --git a/searchcore/src/apps/tests/persistenceconformance_test.cpp b/searchcore/src/apps/tests/persistenceconformance_test.cpp index 44fb2770594..26a44606898 100644 --- a/searchcore/src/apps/tests/persistenceconformance_test.cpp +++ b/searchcore/src/apps/tests/persistenceconformance_test.cpp @@ -127,6 +127,7 @@ public: 1, std::make_shared<RankProfilesConfig>(), std::make_shared<matching::RankingConstants>(), + std::make_shared<matching::OnnxModels>(), indexschema, attributes, summary, diff --git a/searchcore/src/apps/verify_ranksetup/verify_ranksetup.cpp b/searchcore/src/apps/verify_ranksetup/verify_ranksetup.cpp index 118cad4d8ef..7043c450047 100644 --- a/searchcore/src/apps/verify_ranksetup/verify_ranksetup.cpp +++ b/searchcore/src/apps/verify_ranksetup/verify_ranksetup.cpp @@ -12,6 +12,7 @@ #include <vespa/eval/tensor/default_tensor_engine.h> #include <vespa/searchcommon/common/schemaconfigurer.h> #include <vespa/searchcore/config/config-ranking-constants.h> +#include <vespa/searchcore/config/config-onnx-models.h> #include <vespa/searchcore/proton/matching/indexenvironment.h> #include <vespa/searchlib/features/setup.h> #include <vespa/searchlib/fef/fef.h> @@ -28,10 +29,12 @@ using config::ConfigSubscriber; using config::IConfigContext; using config::InvalidConfigException; using proton::matching::IConstantValueRepo; +using proton::matching::OnnxModels; using vespa::config::search::AttributesConfig; using vespa::config::search::IndexschemaConfig; using vespa::config::search::RankProfilesConfig; using vespa::config::search::core::RankingConstantsConfig; +using vespa::config::search::core::OnnxModelsConfig; using vespalib::eval::ConstantValue; using vespalib::eval::TensorSpec; using vespalib::eval::ValueType; @@ -39,17 +42,30 @@ using vespalib::tensor::DefaultTensorEngine; using vespalib::eval::SimpleConstantValue; using vespalib::eval::BadConstantValue; +OnnxModels make_models(const OnnxModelsConfig &modelsCfg) { + OnnxModels::Vector model_list; + for (const auto &entry: modelsCfg.model) { + // TODO(havardpe): resolve model path + vespalib::string model_path = entry.name; + model_path += ".onnx"; + model_list.emplace_back(entry.name, model_path); + } + return OnnxModels(model_list); +} + class App : public FastOS_Application { public: bool verify(const search::index::Schema &schema, const search::fef::Properties &props, - const IConstantValueRepo &repo); + const IConstantValueRepo &repo, + OnnxModels models); bool verifyConfig(const RankProfilesConfig &rankCfg, const IndexschemaConfig &schemaCfg, const AttributesConfig &attributeCfg, - const RankingConstantsConfig &constantsCfg); + const RankingConstantsConfig &constantsCfg, + const OnnxModelsConfig &modelsCfg); int usage(); int Main() override; @@ -77,9 +93,10 @@ struct DummyConstantValueRepo : IConstantValueRepo { bool App::verify(const search::index::Schema &schema, const search::fef::Properties &props, - const IConstantValueRepo &repo) + const IConstantValueRepo &repo, + OnnxModels models) { - proton::matching::IndexEnvironment indexEnv(0, schema, props, repo); + proton::matching::IndexEnvironment indexEnv(0, schema, props, repo, models); search::fef::BlueprintFactory factory; search::features::setup_search_features(factory); search::fef::test::setup_fef_test_plugin(factory); @@ -106,13 +123,15 @@ bool App::verifyConfig(const RankProfilesConfig &rankCfg, const IndexschemaConfig &schemaCfg, const AttributesConfig &attributeCfg, - const RankingConstantsConfig &constantsCfg) + const RankingConstantsConfig &constantsCfg, + const OnnxModelsConfig &modelsCfg) { bool ok = true; search::index::Schema schema; search::index::SchemaBuilder::build(schemaCfg, schema); search::index::SchemaBuilder::build(attributeCfg, schema); DummyConstantValueRepo repo(constantsCfg); + auto models = make_models(modelsCfg); for(size_t i = 0; i < rankCfg.rankprofile.size(); i++) { search::fef::Properties properties; const RankProfilesConfig::Rankprofile &profile = rankCfg.rankprofile[i]; @@ -120,7 +139,7 @@ App::verifyConfig(const RankProfilesConfig &rankCfg, properties.add(profile.fef.property[j].name, profile.fef.property[j].value); } - if (verify(schema, properties, repo)) { + if (verify(schema, properties, repo, models)) { LOG(info, "rank profile '%s': pass", profile.name.c_str()); } else { LOG(error, "rank profile '%s': FAIL", profile.name.c_str()); @@ -157,12 +176,14 @@ App::Main() ConfigHandle<AttributesConfig>::UP attributesHandle = subscriber.subscribe<AttributesConfig>(cfgId); ConfigHandle<IndexschemaConfig>::UP schemaHandle = subscriber.subscribe<IndexschemaConfig>(cfgId); ConfigHandle<RankingConstantsConfig>::UP constantsHandle = subscriber.subscribe<RankingConstantsConfig>(cfgId); + ConfigHandle<OnnxModelsConfig>::UP modelsHandle = subscriber.subscribe<OnnxModelsConfig>(cfgId); subscriber.nextConfig(); ok = verifyConfig(*rankHandle->getConfig(), *schemaHandle->getConfig(), *attributesHandle->getConfig(), - *constantsHandle->getConfig()); + *constantsHandle->getConfig(), + *modelsHandle->getConfig()); } catch (ConfigRuntimeException & e) { LOG(error, "Unable to subscribe to config: %s", e.getMessage().c_str()); } catch (InvalidConfigException & e) { diff --git a/searchcore/src/tests/proton/documentdb/configurer/configurer_test.cpp b/searchcore/src/tests/proton/documentdb/configurer/configurer_test.cpp index 1da35c9f5c3..b2903f00226 100644 --- a/searchcore/src/tests/proton/documentdb/configurer/configurer_test.cpp +++ b/searchcore/src/tests/proton/documentdb/configurer/configurer_test.cpp @@ -646,6 +646,7 @@ TEST("require that maintenance controller should change if some config has chang TEST_DO(assertMaintenanceControllerShouldChange(CCR().setRankProfilesChanged(true))); TEST_DO(assertMaintenanceControllerShouldChange(CCR().setRankingConstantsChanged(true))); + TEST_DO(assertMaintenanceControllerShouldChange(CCR().setOnnxModelsChanged(true))); TEST_DO(assertMaintenanceControllerShouldChange(CCR().setIndexschemaChanged(true))); TEST_DO(assertMaintenanceControllerShouldChange(CCR().setAttributesChanged(true))); TEST_DO(assertMaintenanceControllerShouldChange(CCR().setSummaryChanged(true))); @@ -692,6 +693,7 @@ TEST("require that subdbs should change if relevant config changed") TEST_DO(assertSubDbsShouldChange(CCR().setVisibilityDelayChanged(true))); TEST_DO(assertSubDbsShouldChange(CCR().setRankProfilesChanged(true))); TEST_DO(assertSubDbsShouldChange(CCR().setRankingConstantsChanged(true))); + TEST_DO(assertSubDbsShouldChange(CCR().setOnnxModelsChanged(true))); TEST_DO(assertSubDbsShouldChange(CCR().setSchemaChanged(true))); } diff --git a/searchcore/src/tests/proton/documentdb/documentdbconfig/documentdbconfig_test.cpp b/searchcore/src/tests/proton/documentdb/documentdbconfig/documentdbconfig_test.cpp index a2b824b88ba..aed01ca0192 100644 --- a/searchcore/src/tests/proton/documentdb/documentdbconfig/documentdbconfig_test.cpp +++ b/searchcore/src/tests/proton/documentdb/documentdbconfig/documentdbconfig_test.cpp @@ -17,6 +17,7 @@ using namespace search::index; using namespace search; using namespace vespa::config::search; using proton::matching::RankingConstants; +using proton::matching::OnnxModels; using std::make_shared; using std::shared_ptr; using document::config_builder::DocumenttypesConfigBuilderHelper; @@ -68,6 +69,11 @@ public: _builder.rankingConstants(make_shared<RankingConstants>(constants)); return *this; } + MyConfigBuilder &addOnnxModel() { + OnnxModels::Vector models = {{"my_model_name", "my_model_file"}}; + _builder.onnxModels(make_shared<OnnxModels>(models)); + return *this; + } MyConfigBuilder &addImportedField() { ImportedFieldsConfigBuilder builder; builder.attribute.resize(1); @@ -132,6 +138,7 @@ struct Fixture { fullCfg = MyConfigBuilder(4, schema, repo).addAttribute(). addRankProfile(). addRankingConstant(). + addOnnxModel(). addImportedField(). addSummary(true). addSummarymap(). @@ -166,12 +173,14 @@ struct DelayAttributeAspectFixture { attrCfg = MyConfigBuilder(4, schema, makeDocTypeRepo(true)).addAttribute(). addRankProfile(). addRankingConstant(). + addOnnxModel(). addImportedField(). addSummary(true). addSummarymap(). build(); noAttrCfg = MyConfigBuilder(4, schema, makeDocTypeRepo(hasDocField)).addRankProfile(). addRankingConstant(). + addOnnxModel(). addImportedField(). addSummary(hasDocField). build(); diff --git a/searchcore/src/tests/proton/documentdb/fileconfigmanager/fileconfigmanager_test.cpp b/searchcore/src/tests/proton/documentdb/fileconfigmanager/fileconfigmanager_test.cpp index 2782117d8ae..2352fda65a0 100644 --- a/searchcore/src/tests/proton/documentdb/fileconfigmanager/fileconfigmanager_test.cpp +++ b/searchcore/src/tests/proton/documentdb/fileconfigmanager/fileconfigmanager_test.cpp @@ -28,6 +28,7 @@ using namespace vespa::config::search; using namespace std::chrono_literals; using vespa::config::content::core::BucketspacesConfig; using proton::matching::RankingConstants; +using proton::matching::OnnxModels; typedef DocumentDBConfigHelper DBCM; typedef DocumentDBConfig::DocumenttypesConfigSP DocumenttypesConfigSP; @@ -77,7 +78,9 @@ assertEqualSnapshot(const DocumentDBConfig &exp, const DocumentDBConfig &act) { EXPECT_TRUE(exp.getRankProfilesConfig() == act.getRankProfilesConfig()); EXPECT_TRUE(exp.getRankingConstants() == act.getRankingConstants()); + EXPECT_TRUE(exp.getOnnxModels() == act.getOnnxModels()); EXPECT_EQUAL(0u, exp.getRankingConstants().size()); + EXPECT_EQUAL(0u, exp.getOnnxModels().size()); EXPECT_TRUE(exp.getIndexschemaConfig() == act.getIndexschemaConfig()); EXPECT_TRUE(exp.getAttributesConfig() == act.getAttributesConfig()); EXPECT_TRUE(exp.getSummaryConfig() == act.getSummaryConfig()); @@ -105,6 +108,9 @@ addConfigsThatAreNotSavedToDisk(const DocumentDBConfig &cfg) RankingConstants::Vector constants = {{"my_name", "my_type", "my_path"}}; builder.rankingConstants(std::make_shared<RankingConstants>(constants)); + OnnxModels::Vector models = {{"my_model_name", "my_model_file"}}; + builder.onnxModels(std::make_shared<OnnxModels>(models)); + ImportedFieldsConfigBuilder importedFields; importedFields.attribute.resize(1); importedFields.attribute.back().name = "my_name"; diff --git a/searchcore/src/tests/proton/matching/index_environment/index_environment_test.cpp b/searchcore/src/tests/proton/matching/index_environment/index_environment_test.cpp index 932ab6f4d14..508a60480d0 100644 --- a/searchcore/src/tests/proton/matching/index_environment/index_environment_test.cpp +++ b/searchcore/src/tests/proton/matching/index_environment/index_environment_test.cpp @@ -14,6 +14,13 @@ using search::index::schema::DataType; using vespalib::eval::ConstantValue; using SIAF = Schema::ImportedAttributeField; +OnnxModels make_models() { + OnnxModels::Vector list; + list.emplace_back("model1", "path1"); + list.emplace_back("model2", "path2"); + return OnnxModels(list); +} + struct MyConstantValueRepo : public IConstantValueRepo { virtual ConstantValue::UP getConstant(const vespalib::string &) const override { return ConstantValue::UP(); @@ -42,7 +49,7 @@ struct Fixture { Fixture(Schema::UP schema_) : repo(), schema(std::move(schema_)), - env(7, *schema, Properties(), repo) + env(7, *schema, Properties(), repo, make_models()) { } const FieldInfo *assertField(size_t idx, @@ -97,4 +104,10 @@ TEST_F("require that imported attribute fields are extracted in index environmen EXPECT_EQUAL("[documentmetastore]", f.env.getField(2)->name()); } +TEST_F("require that onnx model paths can be obtained", Fixture(buildEmptySchema())) { + EXPECT_EQUAL(f1.env.getOnnxModelFullPath("model1").value(), vespalib::string("path1")); + EXPECT_EQUAL(f1.env.getOnnxModelFullPath("model2").value(), vespalib::string("path2")); + EXPECT_FALSE(f1.env.getOnnxModelFullPath("model3").has_value()); +} + TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/searchcore/src/tests/proton/matching/matching_test.cpp b/searchcore/src/tests/proton/matching/matching_test.cpp index 9d5b67af81c..0ea63bce859 100644 --- a/searchcore/src/tests/proton/matching/matching_test.cpp +++ b/searchcore/src/tests/proton/matching/matching_test.cpp @@ -278,7 +278,7 @@ struct MyWorld { } Matcher::SP createMatcher() { - return std::make_shared<Matcher>(schema, config, clock, queryLimiter, constantValueRepo, 0); + return std::make_shared<Matcher>(schema, config, clock, queryLimiter, constantValueRepo, OnnxModels(), 0); } struct MySearchHandler : ISearchHandler { diff --git a/searchcore/src/tests/proton/proton_config_fetcher/proton_config_fetcher_test.cpp b/searchcore/src/tests/proton/proton_config_fetcher/proton_config_fetcher_test.cpp index a947074a917..1e64a8f4ecb 100644 --- a/searchcore/src/tests/proton/proton_config_fetcher/proton_config_fetcher_test.cpp +++ b/searchcore/src/tests/proton/proton_config_fetcher/proton_config_fetcher_test.cpp @@ -8,6 +8,7 @@ #include <vespa/searchcore/proton/server/i_proton_configurer.h> #include <vespa/searchcore/proton/common/hw_info.h> #include <vespa/searchcore/config/config-ranking-constants.h> +#include <vespa/searchcore/config/config-onnx-models.h> #include <vespa/searchsummary/config/config-juniperrc.h> #include <vespa/document/repo/documenttyperepo.h> #include <vespa/fileacquirer/config-filedistributorrpc.h> @@ -45,6 +46,7 @@ struct DoctypeFixture { AttributesConfigBuilder attributesBuilder; RankProfilesConfigBuilder rankProfilesBuilder; RankingConstantsConfigBuilder rankingConstantsBuilder; + OnnxModelsConfigBuilder onnxModelsBuilder; IndexschemaConfigBuilder indexschemaBuilder; SummaryConfigBuilder summaryBuilder; SummarymapConfigBuilder summarymapBuilder; @@ -100,6 +102,7 @@ struct ConfigTestFixture { set.addBuilder(db.configid, &fixture->attributesBuilder); set.addBuilder(db.configid, &fixture->rankProfilesBuilder); set.addBuilder(db.configid, &fixture->rankingConstantsBuilder); + set.addBuilder(db.configid, &fixture->onnxModelsBuilder); set.addBuilder(db.configid, &fixture->indexschemaBuilder); set.addBuilder(db.configid, &fixture->summaryBuilder); set.addBuilder(db.configid, &fixture->summarymapBuilder); @@ -253,7 +256,7 @@ TEST_FF("require that documentdb config manager subscribes for config", DocumentDBConfigManager(f1.configId + "/typea", "typea")) { f1.addDocType("typea"); const ConfigKeySet keySet(f2.createConfigKeySet()); - ASSERT_EQUAL(8u, keySet.size()); + ASSERT_EQUAL(9u, keySet.size()); ASSERT_TRUE(f1.configEqual("typea", getDocumentDBConfig(f1, f2))); } diff --git a/searchcore/src/tests/proton/proton_configurer/proton_configurer_test.cpp b/searchcore/src/tests/proton/proton_configurer/proton_configurer_test.cpp index 83706d966ae..6190177ac9d 100644 --- a/searchcore/src/tests/proton/proton_configurer/proton_configurer_test.cpp +++ b/searchcore/src/tests/proton/proton_configurer/proton_configurer_test.cpp @@ -19,6 +19,7 @@ #include <vespa/searchcore/proton/server/i_proton_disk_layout.h> #include <vespa/searchsummary/config/config-juniperrc.h> #include <vespa/searchcore/config/config-ranking-constants.h> +#include <vespa/searchcore/config/config-onnx-models.h> #include <vespa/vespalib/gtest/gtest.h> #include <vespa/searchcommon/common/schemaconfigurer.h> #include <vespa/vespalib/util/threadstackexecutor.h> @@ -44,12 +45,14 @@ using std::map; using search::index::Schema; using search::index::SchemaBuilder; using proton::matching::RankingConstants; +using proton::matching::OnnxModels; struct DBConfigFixture { using UP = std::unique_ptr<DBConfigFixture>; AttributesConfigBuilder _attributesBuilder; RankProfilesConfigBuilder _rankProfilesBuilder; RankingConstantsConfigBuilder _rankingConstantsBuilder; + OnnxModelsConfigBuilder _onnxModelsBuilder; IndexschemaConfigBuilder _indexschemaBuilder; SummaryConfigBuilder _summaryBuilder; SummarymapConfigBuilder _summarymapBuilder; @@ -70,6 +73,11 @@ struct DBConfigFixture { return std::make_shared<RankingConstants>(); } + OnnxModels::SP buildOnnxModels() + { + return std::make_shared<OnnxModels>(); + } + DocumentDBConfig::SP getConfig(int64_t generation, std::shared_ptr<DocumenttypesConfig> documentTypes, std::shared_ptr<const DocumentTypeRepo> repo, @@ -80,6 +88,7 @@ struct DBConfigFixture { (generation, std::make_shared<RankProfilesConfig>(_rankProfilesBuilder), buildRankingConstants(), + buildOnnxModels(), std::make_shared<IndexschemaConfig>(_indexschemaBuilder), std::make_shared<AttributesConfig>(_attributesBuilder), std::make_shared<SummaryConfig>(_summaryBuilder), diff --git a/searchcore/src/vespa/searchcore/config/CMakeLists.txt b/searchcore/src/vespa/searchcore/config/CMakeLists.txt index a4f5560c712..915ab147978 100644 --- a/searchcore/src/vespa/searchcore/config/CMakeLists.txt +++ b/searchcore/src/vespa/searchcore/config/CMakeLists.txt @@ -9,4 +9,6 @@ vespa_generate_config(searchcore_fconfig proton.def) install_config_definition(proton.def vespa.config.search.core.proton.def) vespa_generate_config(searchcore_fconfig ranking-constants.def) install_config_definition(ranking-constants.def vespa.config.search.core.ranking-constants.def) +vespa_generate_config(searchcore_fconfig onnx-models.def) +install_config_definition(onnx-models.def vespa.config.search.core.onnx-models.def) vespa_generate_config(searchcore_fconfig hwinfo.def) diff --git a/searchcore/src/vespa/searchcore/proton/matching/CMakeLists.txt b/searchcore/src/vespa/searchcore/proton/matching/CMakeLists.txt index ffbab597118..a4688b5fdca 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/CMakeLists.txt +++ b/searchcore/src/vespa/searchcore/proton/matching/CMakeLists.txt @@ -20,6 +20,7 @@ vespa_add_library(searchcore_matching STATIC match_tools.cpp matcher.cpp matching_stats.cpp + onnx_models.cpp partial_result.cpp query.cpp queryenvironment.cpp diff --git a/searchcore/src/vespa/searchcore/proton/matching/indexenvironment.cpp b/searchcore/src/vespa/searchcore/proton/matching/indexenvironment.cpp index d6d185ccab4..5743a3d44d6 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/indexenvironment.cpp +++ b/searchcore/src/vespa/searchcore/proton/matching/indexenvironment.cpp @@ -64,13 +64,15 @@ IndexEnvironment::insertField(const search::fef::FieldInfo &field) IndexEnvironment::IndexEnvironment(uint32_t distributionKey, const search::index::Schema &schema, const search::fef::Properties &props, - const IConstantValueRepo &constantValueRepo) + const IConstantValueRepo &constantValueRepo, + OnnxModels onnxModels) : _tableManager(), _properties(props), _fieldNames(), _fields(), _motivation(UNKNOWN), _constantValueRepo(constantValueRepo), + _onnxModels(std::move(onnxModels)), _distributionKey(distributionKey) { _tableManager.addFactory(std::make_shared<search::fef::FunctionTableFactory>(256)); @@ -129,6 +131,15 @@ IndexEnvironment::hintFieldAccess(uint32_t ) const { } void IndexEnvironment::hintAttributeAccess(const string &) const { } +std::optional<vespalib::string> +IndexEnvironment::getOnnxModelFullPath(const vespalib::string &name) const +{ + if (const auto model = _onnxModels.getModel(name)) { + return model->filePath; + } + return std::nullopt; +} + IndexEnvironment::~IndexEnvironment() = default; } diff --git a/searchcore/src/vespa/searchcore/proton/matching/indexenvironment.h b/searchcore/src/vespa/searchcore/proton/matching/indexenvironment.h index 7da45909577..d0e9a516cd0 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/indexenvironment.h +++ b/searchcore/src/vespa/searchcore/proton/matching/indexenvironment.h @@ -2,6 +2,7 @@ #pragma once +#include "onnx_models.h" #include "i_constant_value_repo.h" #include <vespa/searchlib/fef/fieldinfo.h> #include <vespa/searchlib/fef/iindexenvironment.h> @@ -25,6 +26,7 @@ private: std::vector<search::fef::FieldInfo> _fields; mutable FeatureMotivation _motivation; const IConstantValueRepo &_constantValueRepo; + OnnxModels _onnxModels; uint32_t _distributionKey; @@ -44,11 +46,13 @@ public: * @param schema the index schema * @param props config * @param constantValueRepo repo used to access constant values for ranking + * @param onnxModels processed config about onnx models **/ IndexEnvironment(uint32_t distributionKey, const search::index::Schema &schema, const search::fef::Properties &props, - const IConstantValueRepo &constantValueRepo); + const IConstantValueRepo &constantValueRepo, + OnnxModels onnxModels); const search::fef::Properties &getProperties() const override; uint32_t getNumFields() const override; @@ -65,6 +69,7 @@ public: return _constantValueRepo.getConstant(name); } + std::optional<vespalib::string> getOnnxModelFullPath(const vespalib::string &name) const override; ~IndexEnvironment() override; }; diff --git a/searchcore/src/vespa/searchcore/proton/matching/matcher.cpp b/searchcore/src/vespa/searchcore/proton/matching/matcher.cpp index 735070002eb..98c4fdaa89a 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/matcher.cpp +++ b/searchcore/src/vespa/searchcore/proton/matching/matcher.cpp @@ -99,8 +99,8 @@ handleGroupingSession(SessionManager &sessionMgr, GroupingContext & groupingCont } // namespace proton::matching::<unnamed> Matcher::Matcher(const search::index::Schema &schema, const Properties &props, const vespalib::Clock &clock, - QueryLimiter &queryLimiter, const IConstantValueRepo &constantValueRepo, uint32_t distributionKey) - : _indexEnv(distributionKey, schema, props, constantValueRepo), + QueryLimiter &queryLimiter, const IConstantValueRepo &constantValueRepo, OnnxModels onnxModels, uint32_t distributionKey) + : _indexEnv(distributionKey, schema, props, constantValueRepo, std::move(onnxModels)), _blueprintFactory(), _rankSetup(), _viewResolver(ViewResolver::createFromSchema(schema)), diff --git a/searchcore/src/vespa/searchcore/proton/matching/matcher.h b/searchcore/src/vespa/searchcore/proton/matching/matcher.h index 243fdad63ae..39d1fa38007 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/matcher.h +++ b/searchcore/src/vespa/searchcore/proton/matching/matcher.h @@ -89,7 +89,8 @@ public: **/ Matcher(const search::index::Schema &schema, const Properties &props, const vespalib::Clock &clock, QueryLimiter &queryLimiter, - const IConstantValueRepo &constantValueRepo, uint32_t distributionKey); + const IConstantValueRepo &constantValueRepo, OnnxModels onnxModels, + uint32_t distributionKey); const search::fef::IIndexEnvironment &get_index_env() const { return _indexEnv; } diff --git a/searchcore/src/vespa/searchcore/proton/matching/onnx_models.cpp b/searchcore/src/vespa/searchcore/proton/matching/onnx_models.cpp new file mode 100644 index 00000000000..bdcf3e21d8e --- /dev/null +++ b/searchcore/src/vespa/searchcore/proton/matching/onnx_models.cpp @@ -0,0 +1,54 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "onnx_models.h" + +namespace proton::matching { + +OnnxModels::Model::Model(const vespalib::string &name_in, + const vespalib::string &filePath_in) + : name(name_in), + filePath(filePath_in) +{ +} + +OnnxModels::Model::~Model() = default; + +bool +OnnxModels::Model::operator==(const Model &rhs) const +{ + return (name == rhs.name) && + (filePath == rhs.filePath); +} + +OnnxModels::OnnxModels() + : _models() +{ +} + +OnnxModels::~OnnxModels() = default; + +OnnxModels::OnnxModels(const Vector &models) + : _models() +{ + for (const auto &model : models) { + _models.insert(std::make_pair(model.name, model)); + } +} + +bool +OnnxModels::operator==(const OnnxModels &rhs) const +{ + return _models == rhs._models; +} + +const OnnxModels::Model * +OnnxModels::getModel(const vespalib::string &name) const +{ + auto itr = _models.find(name); + if (itr != _models.end()) { + return &itr->second; + } + return nullptr; +} + +} diff --git a/searchcore/src/vespa/searchcore/proton/matching/onnx_models.h b/searchcore/src/vespa/searchcore/proton/matching/onnx_models.h new file mode 100644 index 00000000000..fdaae657711 --- /dev/null +++ b/searchcore/src/vespa/searchcore/proton/matching/onnx_models.h @@ -0,0 +1,43 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include <vespa/vespalib/stllike/string.h> +#include <map> +#include <vector> + +namespace proton::matching { + +/** + * Class representing a set of configured onnx models, with full path + * for where the models are stored on disk. + */ +class OnnxModels { +public: + struct Model { + vespalib::string name; + vespalib::string filePath; + + Model(const vespalib::string &name_in, + const vespalib::string &filePath_in); + ~Model(); + bool operator==(const Model &rhs) const; + }; + + using Vector = std::vector<Model>; + +private: + using Map = std::map<vespalib::string, Model>; + Map _models; + +public: + using SP = std::shared_ptr<OnnxModels>; + OnnxModels(); + OnnxModels(const Vector &models); + ~OnnxModels(); + bool operator==(const OnnxModels &rhs) const; + const Model *getModel(const vespalib::string &name) const; + size_t size() const { return _models.size(); } +}; + +} diff --git a/searchcore/src/vespa/searchcore/proton/server/documentdbconfig.cpp b/searchcore/src/vespa/searchcore/proton/server/documentdbconfig.cpp index 712bc553d08..8bcf8440101 100644 --- a/searchcore/src/vespa/searchcore/proton/server/documentdbconfig.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/documentdbconfig.cpp @@ -11,6 +11,7 @@ #include <vespa/document/config/config-documenttypes.h> #include <vespa/document/repo/documenttyperepo.h> #include <vespa/searchcore/config/config-ranking-constants.h> +#include <vespa/searchcore/config/config-onnx-models.h> #include <vespa/searchcore/proton/attribute/attribute_aspect_delayer.h> #include <vespa/searchcore/proton/common/document_type_inspector.h> #include <vespa/searchcore/proton/common/indexschema_inspector.h> @@ -25,12 +26,14 @@ using search::TuneFileDocumentDB; using search::index::Schema; using vespa::config::search::SummarymapConfig; using vespa::config::search::core::RankingConstantsConfig; +using vespa::config::search::core::OnnxModelsConfig; namespace proton { DocumentDBConfig::ComparisonResult::ComparisonResult() : rankProfilesChanged(false), rankingConstantsChanged(false), + onnxModelsChanged(false), indexschemaChanged(false), attributesChanged(false), summaryChanged(false), @@ -51,6 +54,7 @@ DocumentDBConfig::DocumentDBConfig( int64_t generation, const RankProfilesConfigSP &rankProfiles, const RankingConstants::SP &rankingConstants, + const OnnxModels::SP &onnxModels, const IndexschemaConfigSP &indexschema, const AttributesConfigSP &attributes, const SummaryConfigSP &summary, @@ -70,6 +74,7 @@ DocumentDBConfig::DocumentDBConfig( _generation(generation), _rankProfiles(rankProfiles), _rankingConstants(rankingConstants), + _onnxModels(onnxModels), _indexschema(indexschema), _attributes(attributes), _summary(summary), @@ -94,6 +99,7 @@ DocumentDBConfig(const DocumentDBConfig &cfg) _generation(cfg._generation), _rankProfiles(cfg._rankProfiles), _rankingConstants(cfg._rankingConstants), + _onnxModels(cfg._onnxModels), _indexschema(cfg._indexschema), _attributes(cfg._attributes), _summary(cfg._summary), @@ -117,6 +123,7 @@ DocumentDBConfig::operator==(const DocumentDBConfig & rhs) const { return equals<RankProfilesConfig>(_rankProfiles.get(), rhs._rankProfiles.get()) && equals<RankingConstants>(_rankingConstants.get(), rhs._rankingConstants.get()) && + equals<OnnxModels>(_onnxModels.get(), rhs._onnxModels.get()) && equals<IndexschemaConfig>(_indexschema.get(), rhs._indexschema.get()) && equals<AttributesConfig>(_attributes.get(), rhs._attributes.get()) && equals<SummaryConfig>(_summary.get(), rhs._summary.get()) && @@ -138,6 +145,7 @@ DocumentDBConfig::compare(const DocumentDBConfig &rhs) const ComparisonResult retval; retval.rankProfilesChanged = !equals<RankProfilesConfig>(_rankProfiles.get(), rhs._rankProfiles.get()); retval.rankingConstantsChanged = !equals<RankingConstants>(_rankingConstants.get(), rhs._rankingConstants.get()); + retval.onnxModelsChanged = !equals<OnnxModels>(_onnxModels.get(), rhs._onnxModels.get()); retval.indexschemaChanged = !equals<IndexschemaConfig>(_indexschema.get(), rhs._indexschema.get()); retval.attributesChanged = !equals<AttributesConfig>(_attributes.get(), rhs._attributes.get()); retval.summaryChanged = !equals<SummaryConfig>(_summary.get(), rhs._summary.get()); @@ -161,6 +169,7 @@ DocumentDBConfig::valid() const { return _rankProfiles && _rankingConstants && + _onnxModels && _indexschema && _attributes && _summary && @@ -201,6 +210,7 @@ DocumentDBConfig::makeReplayConfig(const SP & orig) o._generation, emptyConfig(o._rankProfiles), std::make_shared<RankingConstants>(), + std::make_shared<OnnxModels>(), o._indexschema, o._attributes, o._summary, @@ -241,6 +251,7 @@ DocumentDBConfig::newFromAttributesConfig(const AttributesConfigSP &attributes) _generation, _rankProfiles, _rankingConstants, + _onnxModels, _indexschema, attributes, _summary, @@ -276,6 +287,7 @@ DocumentDBConfig::makeDelayedAttributeAspectConfig(const SP &newCfg, const Docum (n._generation, n._rankProfiles, n._rankingConstants, + n._onnxModels, n._indexschema, attributeAspectDelayer.getAttributesConfig(), n._summary, diff --git a/searchcore/src/vespa/searchcore/proton/server/documentdbconfig.h b/searchcore/src/vespa/searchcore/proton/server/documentdbconfig.h index c4083c3db7a..09fdd5b5b0a 100644 --- a/searchcore/src/vespa/searchcore/proton/server/documentdbconfig.h +++ b/searchcore/src/vespa/searchcore/proton/server/documentdbconfig.h @@ -6,6 +6,7 @@ #include <vespa/searchlib/common/tunefileinfo.h> #include <vespa/searchcommon/common/schema.h> #include <vespa/searchcore/proton/matching/ranking_constants.h> +#include <vespa/searchcore/proton/matching/onnx_models.h> #include <vespa/config/retriever/configkeyset.h> #include <vespa/config/retriever/configsnapshot.h> #include <vespa/searchlib/docstore/logdocumentstore.h> @@ -36,6 +37,7 @@ public: public: bool rankProfilesChanged; bool rankingConstantsChanged; + bool onnxModelsChanged; bool indexschemaChanged; bool attributesChanged; bool summaryChanged; @@ -54,6 +56,7 @@ public: ComparisonResult(); ComparisonResult &setRankProfilesChanged(bool val) { rankProfilesChanged = val; return *this; } ComparisonResult &setRankingConstantsChanged(bool val) { rankingConstantsChanged = val; return *this; } + ComparisonResult &setOnnxModelsChanged(bool val) { onnxModelsChanged = val; return *this; } ComparisonResult &setIndexschemaChanged(bool val) { indexschemaChanged = val; return *this; } ComparisonResult &setAttributesChanged(bool val) { attributesChanged = val; return *this; } ComparisonResult &setSummaryChanged(bool val) { summaryChanged = val; return *this; } @@ -91,6 +94,7 @@ public: using RankProfilesConfig = const vespa::config::search::internal::InternalRankProfilesType; using RankProfilesConfigSP = std::shared_ptr<RankProfilesConfig>; using RankingConstants = matching::RankingConstants; + using OnnxModels = matching::OnnxModels; using SummaryConfig = const vespa::config::search::internal::InternalSummaryType; using SummaryConfigSP = std::shared_ptr<SummaryConfig>; using SummarymapConfig = const vespa::config::search::internal::InternalSummarymapType; @@ -109,6 +113,7 @@ private: int64_t _generation; RankProfilesConfigSP _rankProfiles; RankingConstants::SP _rankingConstants; + OnnxModels::SP _onnxModels; IndexschemaConfigSP _indexschema; AttributesConfigSP _attributes; SummaryConfigSP _summary; @@ -145,6 +150,7 @@ public: DocumentDBConfig(int64_t generation, const RankProfilesConfigSP &rankProfiles, const RankingConstants::SP &rankingConstants, + const OnnxModels::SP &onnxModels, const IndexschemaConfigSP &indexschema, const AttributesConfigSP &attributes, const SummaryConfigSP &summary, @@ -172,6 +178,7 @@ public: const RankProfilesConfig &getRankProfilesConfig() const { return *_rankProfiles; } const RankingConstants &getRankingConstants() const { return *_rankingConstants; } + const OnnxModels &getOnnxModels() const { return *_onnxModels; } const IndexschemaConfig &getIndexschemaConfig() const { return *_indexschema; } const AttributesConfig &getAttributesConfig() const { return *_attributes; } const SummaryConfig &getSummaryConfig() const { return *_summary; } @@ -180,6 +187,7 @@ public: const DocumenttypesConfig &getDocumenttypesConfig() const { return *_documenttypes; } const RankProfilesConfigSP &getRankProfilesConfigSP() const { return _rankProfiles; } const RankingConstants::SP &getRankingConstantsSP() const { return _rankingConstants; } + const OnnxModels::SP &getOnnxModelsSP() const { return _onnxModels; } const IndexschemaConfigSP &getIndexschemaConfigSP() const { return _indexschema; } const AttributesConfigSP &getAttributesConfigSP() const { return _attributes; } const SummaryConfigSP &getSummaryConfigSP() const { return _summary; } diff --git a/searchcore/src/vespa/searchcore/proton/server/documentdbconfigmanager.cpp b/searchcore/src/vespa/searchcore/proton/server/documentdbconfigmanager.cpp index 68e65acb87d..a8996abc856 100644 --- a/searchcore/src/vespa/searchcore/proton/server/documentdbconfigmanager.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/documentdbconfigmanager.cpp @@ -4,6 +4,7 @@ #include "bootstrapconfig.h" #include <vespa/searchcore/proton/common/hw_info.h> #include <vespa/searchcore/config/config-ranking-constants.h> +#include <vespa/searchcore/config/config-onnx-models.h> #include <vespa/config-imported-fields.h> #include <vespa/config-rank-profiles.h> #include <vespa/config-summarymap.h> @@ -30,6 +31,7 @@ using search::TuneFileDocumentDB; using search::index::Schema; using search::index::SchemaBuilder; using proton::matching::RankingConstants; +using proton::matching::OnnxModels; using vespalib::compression::CompressionConfig; using search::LogDocumentStore; using search::LogDataStore; @@ -46,6 +48,7 @@ DocumentDBConfigManager::createConfigKeySet() const ConfigKeySet set; set.add<RankProfilesConfig, RankingConstantsConfig, + OnnxModelsConfig, IndexschemaConfig, AttributesConfig, SummaryConfig, @@ -228,6 +231,7 @@ DocumentDBConfigManager::update(const ConfigSnapshot &snapshot) { using RankProfilesConfigSP = DocumentDBConfig::RankProfilesConfigSP; using RankingConstantsConfigSP = std::shared_ptr<vespa::config::search::core::RankingConstantsConfig>; + using OnnxModelsConfigSP = std::shared_ptr<vespa::config::search::core::OnnxModelsConfig>; using IndexschemaConfigSP = DocumentDBConfig::IndexschemaConfigSP; using SummaryConfigSP = DocumentDBConfig::SummaryConfigSP; using SummarymapConfigSP = DocumentDBConfig::SummarymapConfigSP; @@ -238,6 +242,7 @@ DocumentDBConfigManager::update(const ConfigSnapshot &snapshot) DocumentDBConfig::SP current = _pendingConfigSnapshot; RankProfilesConfigSP newRankProfilesConfig; matching::RankingConstants::SP newRankingConstants; + matching::OnnxModels::SP newOnnxModels; IndexschemaConfigSP newIndexschemaConfig; MaintenanceConfigSP oldMaintenanceConfig; MaintenanceConfigSP newMaintenanceConfig; @@ -261,6 +266,7 @@ DocumentDBConfigManager::update(const ConfigSnapshot &snapshot) if (current) { newRankProfilesConfig = current->getRankProfilesConfigSP(); newRankingConstants = current->getRankingConstantsSP(); + newOnnxModels = current->getOnnxModelsSP(); newIndexschemaConfig = current->getIndexschemaConfigSP(); oldMaintenanceConfig = current->getMaintenanceConfigSP(); currentGeneration = current->getGeneration(); @@ -294,6 +300,31 @@ DocumentDBConfigManager::update(const ConfigSnapshot &snapshot) } newRankingConstants = std::make_shared<RankingConstants>(constants); } + if (snapshot.isChanged<OnnxModelsConfig>(_configId, currentGeneration)) { + OnnxModelsConfigSP newOnnxModelsConfig = OnnxModelsConfigSP( + snapshot.getConfig<OnnxModelsConfig>(_configId)); + const vespalib::string &spec = _bootstrapConfig->getFiledistributorrpcConfig().connectionspec; + OnnxModels::Vector models; + if (spec != "") { + config::RpcFileAcquirer fileAcquirer(spec); + vespalib::TimeBox timeBox(5*60, 5); + for (const OnnxModelsConfig::Model &rc : newOnnxModelsConfig->model) { + vespalib::string filePath; + LOG(info, "Waiting for file acquirer (name='%s', ref='%s')", + rc.name.c_str(), rc.fileref.c_str()); + while (timeBox.hasTimeLeft() && (filePath == "")) { + filePath = fileAcquirer.wait_for(rc.fileref, timeBox.timeLeft()); + if (filePath == "") { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + } + LOG(info, "Got file path from file acquirer: '%s' (name='%s', ref='%s')", + filePath.c_str(), rc.name.c_str(), rc.fileref.c_str()); + models.emplace_back(rc.name, filePath); + } + } + newOnnxModels = std::make_shared<OnnxModels>(models); + } if (snapshot.isChanged<IndexschemaConfig>(_configId, currentGeneration)) { newIndexschemaConfig = snapshot.getConfig<IndexschemaConfig>(_configId); search::index::Schema schema; @@ -318,6 +349,7 @@ DocumentDBConfigManager::update(const ConfigSnapshot &snapshot) auto newSnapshot = std::make_shared<DocumentDBConfig>(generation, newRankProfilesConfig, newRankingConstants, + newOnnxModels, newIndexschemaConfig, filterImportedAttributes(newAttributesConfig), newSummaryConfig, diff --git a/searchcore/src/vespa/searchcore/proton/server/fileconfigmanager.cpp b/searchcore/src/vespa/searchcore/proton/server/fileconfigmanager.cpp index 395eb5a0ea2..e66043aa422 100644 --- a/searchcore/src/vespa/searchcore/proton/server/fileconfigmanager.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/fileconfigmanager.cpp @@ -345,6 +345,7 @@ FileConfigManager::loadConfig(const DocumentDBConfig ¤tSnapshot, config::DirSpec spec(snapDir); addEmptyFile(snapDir, "ranking-constants.cfg"); + addEmptyFile(snapDir, "onnx-models.cfg"); addEmptyFile(snapDir, "imported-fields.cfg"); DocumentDBConfigHelper dbc(spec, _docTypeName); diff --git a/searchcore/src/vespa/searchcore/proton/server/matchers.cpp b/searchcore/src/vespa/searchcore/proton/server/matchers.cpp index 29e940ca26d..53c96a81134 100644 --- a/searchcore/src/vespa/searchcore/proton/server/matchers.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/matchers.cpp @@ -2,16 +2,19 @@ #include "matchers.h" #include <vespa/searchcore/proton/matching/matcher.h> +#include <vespa/searchcore/proton/matching/onnx_models.h> #include <vespa/vespalib/stllike/hash_map.hpp> namespace proton { +using matching::OnnxModels; + Matchers::Matchers(const vespalib::Clock &clock, matching::QueryLimiter &queryLimiter, const matching::IConstantValueRepo &constantValueRepo) : _rpmap(), _fallback(new matching::Matcher(search::index::Schema(), search::fef::Properties(), - clock, queryLimiter, constantValueRepo, -1)), + clock, queryLimiter, constantValueRepo, OnnxModels(), -1)), _default() { } diff --git a/searchcore/src/vespa/searchcore/proton/server/reconfig_params.cpp b/searchcore/src/vespa/searchcore/proton/server/reconfig_params.cpp index 8ec41ae3e3c..4fc241571ac 100644 --- a/searchcore/src/vespa/searchcore/proton/server/reconfig_params.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/reconfig_params.cpp @@ -15,6 +15,7 @@ ReconfigParams::configHasChanged() const { return _res.rankProfilesChanged || _res.rankingConstantsChanged || + _res.onnxModelsChanged || _res.indexschemaChanged || _res.attributesChanged || _res.summaryChanged || @@ -38,7 +39,7 @@ ReconfigParams::shouldSchemaChange() const bool ReconfigParams::shouldMatchersChange() const { - return _res.rankProfilesChanged || _res.rankingConstantsChanged || shouldSchemaChange(); + return _res.rankProfilesChanged || _res.rankingConstantsChanged || _res.onnxModelsChanged || shouldSchemaChange(); } bool diff --git a/searchcore/src/vespa/searchcore/proton/server/searchable_doc_subdb_configurer.cpp b/searchcore/src/vespa/searchcore/proton/server/searchable_doc_subdb_configurer.cpp index 713256fd809..8f34484dfe2 100644 --- a/searchcore/src/vespa/searchcore/proton/server/searchable_doc_subdb_configurer.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/searchable_doc_subdb_configurer.cpp @@ -21,6 +21,7 @@ using vespa::config::search::RankProfilesConfig; namespace proton { using matching::Matcher; +using matching::OnnxModels; typedef AttributeReprocessingInitializer::Config ARIConfig; @@ -122,7 +123,8 @@ SearchableDocSubDBConfigurer::~SearchableDocSubDBConfigurer() = default; Matchers::UP SearchableDocSubDBConfigurer::createMatchers(const Schema::SP &schema, - const RankProfilesConfig &cfg) + const RankProfilesConfig &cfg, + const OnnxModels &onnxModels) { auto newMatchers = std::make_unique<Matchers>(_clock, _queryLimiter, _constantValueRepo); for (const auto &profile : cfg.rankprofile) { @@ -132,7 +134,7 @@ SearchableDocSubDBConfigurer::createMatchers(const Schema::SP &schema, properties.add(property.name, property.value); } // schema instance only used during call. - auto profptr = std::make_shared<Matcher>(*schema, properties, _clock, _queryLimiter, _constantValueRepo, _distributionKey); + auto profptr = std::make_shared<Matcher>(*schema, properties, _clock, _queryLimiter, _constantValueRepo, onnxModels, _distributionKey); newMatchers->add(name, profptr); } return newMatchers; @@ -200,7 +202,9 @@ SearchableDocSubDBConfigurer::reconfigure(const DocumentDBConfig &newConfig, Matchers::SP matchers = searchView->getMatchers(); if (params.shouldMatchersChange()) { _constantValueRepo.reconfigure(newConfig.getRankingConstants()); - Matchers::SP newMatchers = createMatchers(newConfig.getSchemaSP(),newConfig.getRankProfilesConfig()); + Matchers::SP newMatchers = createMatchers(newConfig.getSchemaSP(), + newConfig.getRankProfilesConfig(), + newConfig.getOnnxModels()); matchers = newMatchers; shouldMatchViewChange = true; } diff --git a/searchcore/src/vespa/searchcore/proton/server/searchable_doc_subdb_configurer.h b/searchcore/src/vespa/searchcore/proton/server/searchable_doc_subdb_configurer.h index 6b836544735..0f86520fd0b 100644 --- a/searchcore/src/vespa/searchcore/proton/server/searchable_doc_subdb_configurer.h +++ b/searchcore/src/vespa/searchcore/proton/server/searchable_doc_subdb_configurer.h @@ -80,7 +80,8 @@ public: ~SearchableDocSubDBConfigurer(); Matchers::UP createMatchers(const search::index::Schema::SP &schema, - const vespa::config::search::RankProfilesConfig &cfg); + const vespa::config::search::RankProfilesConfig &cfg, + const proton::matching::OnnxModels &onnxModels); void reconfigureIndexSearchable(); diff --git a/searchcore/src/vespa/searchcore/proton/server/searchabledocsubdb.cpp b/searchcore/src/vespa/searchcore/proton/server/searchabledocsubdb.cpp index 592d1bc1b52..23ab568c767 100644 --- a/searchcore/src/vespa/searchcore/proton/server/searchabledocsubdb.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/searchabledocsubdb.cpp @@ -201,7 +201,7 @@ SearchableDocSubDB::initViews(const DocumentDBConfig &configSnapshot, const Sess const Schema::SP &schema = configSnapshot.getSchemaSP(); const IIndexManager::SP &indexMgr = getIndexManager(); _constantValueRepo.reconfigure(configSnapshot.getRankingConstants()); - Matchers::SP matchers(_configurer.createMatchers(schema, configSnapshot.getRankProfilesConfig()).release()); + Matchers::SP matchers = _configurer.createMatchers(schema, configSnapshot.getRankProfilesConfig(), configSnapshot.getOnnxModels()); auto matchView = std::make_shared<MatchView>(std::move(matchers), indexMgr->getSearchable(), attrMgr, sessionManager, _metaStoreCtx, _docIdLimit); _rSearchView.set(SearchView::create( diff --git a/searchcore/src/vespa/searchcore/proton/test/documentdb_config_builder.cpp b/searchcore/src/vespa/searchcore/proton/test/documentdb_config_builder.cpp index 5cd092d20ba..a2366a3cb92 100644 --- a/searchcore/src/vespa/searchcore/proton/test/documentdb_config_builder.cpp +++ b/searchcore/src/vespa/searchcore/proton/test/documentdb_config_builder.cpp @@ -31,6 +31,7 @@ DocumentDBConfigBuilder::DocumentDBConfigBuilder(int64_t generation, : _generation(generation), _rankProfiles(std::make_shared<RankProfilesConfig>()), _rankingConstants(std::make_shared<matching::RankingConstants>()), + _onnxModels(std::make_shared<matching::OnnxModels>()), _indexschema(std::make_shared<IndexschemaConfig>()), _attributes(std::make_shared<AttributesConfig>()), _summary(std::make_shared<SummaryConfig>()), @@ -52,6 +53,7 @@ DocumentDBConfigBuilder::DocumentDBConfigBuilder(const DocumentDBConfig &cfg) : _generation(cfg.getGeneration()), _rankProfiles(cfg.getRankProfilesConfigSP()), _rankingConstants(cfg.getRankingConstantsSP()), + _onnxModels(cfg.getOnnxModelsSP()), _indexschema(cfg.getIndexschemaConfigSP()), _attributes(cfg.getAttributesConfigSP()), _summary(cfg.getSummaryConfigSP()), @@ -77,6 +79,7 @@ DocumentDBConfigBuilder::build() _generation, _rankProfiles, _rankingConstants, + _onnxModels, _indexschema, _attributes, _summary, diff --git a/searchcore/src/vespa/searchcore/proton/test/documentdb_config_builder.h b/searchcore/src/vespa/searchcore/proton/test/documentdb_config_builder.h index 4a515cf3b19..68fb5454eef 100644 --- a/searchcore/src/vespa/searchcore/proton/test/documentdb_config_builder.h +++ b/searchcore/src/vespa/searchcore/proton/test/documentdb_config_builder.h @@ -14,6 +14,7 @@ private: int64_t _generation; DocumentDBConfig::RankProfilesConfigSP _rankProfiles; DocumentDBConfig::RankingConstants::SP _rankingConstants; + DocumentDBConfig::OnnxModels::SP _onnxModels; DocumentDBConfig::IndexschemaConfigSP _indexschema; DocumentDBConfig::AttributesConfigSP _attributes; DocumentDBConfig::SummaryConfigSP _summary; @@ -54,6 +55,10 @@ public: _rankingConstants = rankingConstants_in; return *this; } + DocumentDBConfigBuilder &onnxModels(const DocumentDBConfig::OnnxModels::SP &onnxModels_in) { + _onnxModels = onnxModels_in; + return *this; + } DocumentDBConfigBuilder &importedFields(const DocumentDBConfig::ImportedFieldsConfigSP &importedFields_in) { _importedFields = importedFields_in; return *this; diff --git a/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp index 7a200a46ab2..826984832f6 100644 --- a/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp +++ b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp @@ -58,9 +58,7 @@ struct OnnxFeatureTest : ::testing::Test { indexEnv.getProperties().add(expr_name, expr); } void add_onnx(const vespalib::string &name, const vespalib::string &file) { - vespalib::string feature_name = onnx_feature(name); - vespalib::string file_name = feature_name + ".fileref"; - indexEnv.getProperties().add(file_name, file); + indexEnv.addOnnxModel(name, file); } void compile(const vespalib::string &seed) { resolver->addSeed(seed); diff --git a/searchlib/src/vespa/searchlib/features/onnx_feature.cpp b/searchlib/src/vespa/searchlib/features/onnx_feature.cpp index 7433021b9b6..b24392ce629 100644 --- a/searchlib/src/vespa/searchlib/features/onnx_feature.cpp +++ b/searchlib/src/vespa/searchlib/features/onnx_feature.cpp @@ -66,15 +66,14 @@ OnnxBlueprint::setup(const IIndexEnvironment &env, auto optimize = (env.getFeatureMotivation() == env.FeatureMotivation::VERIFY_SETUP) ? Onnx::Optimize::DISABLE : Onnx::Optimize::ENABLE; - - // Note: Using the fileref property with the model name as - // fallback to get a file name. This needs to be replaced with an - // actual file reference obtained through config when available. - vespalib::string file_name = env.getProperties().lookup(getName(), "fileref").get(params[0].getValue()); + auto file_name = env.getOnnxModelFullPath(params[0].getValue()); + if (!file_name.has_value()) { + return fail("no model with name '%s' found", params[0].getValue().c_str()); + } try { - _model = std::make_unique<Onnx>(file_name, optimize); + _model = std::make_unique<Onnx>(file_name.value(), optimize); } catch (std::exception &ex) { - return fail("Model setup failed: %s", ex.what()); + return fail("model setup failed: %s", ex.what()); } Onnx::WirePlanner planner; for (size_t i = 0; i < _model->inputs().size(); ++i) { diff --git a/searchlib/src/vespa/searchlib/fef/iindexenvironment.h b/searchlib/src/vespa/searchlib/fef/iindexenvironment.h index bdeead3e852..26e88a98033 100644 --- a/searchlib/src/vespa/searchlib/fef/iindexenvironment.h +++ b/searchlib/src/vespa/searchlib/fef/iindexenvironment.h @@ -3,6 +3,7 @@ #pragma once #include <vespa/vespalib/stllike/string.h> +#include <optional> namespace vespalib::eval { struct ConstantValue; } @@ -120,6 +121,11 @@ public: */ virtual std::unique_ptr<vespalib::eval::ConstantValue> getConstantValue(const vespalib::string &name) const = 0; + /** + * Get the full path of the file containing the given onnx model + **/ + virtual std::optional<vespalib::string> getOnnxModelFullPath(const vespalib::string &name) const = 0; + virtual uint32_t getDistributionKey() const = 0; /** diff --git a/searchlib/src/vespa/searchlib/fef/test/indexenvironment.cpp b/searchlib/src/vespa/searchlib/fef/test/indexenvironment.cpp index e998e4d18bd..6e2e0b88fbb 100644 --- a/searchlib/src/vespa/searchlib/fef/test/indexenvironment.cpp +++ b/searchlib/src/vespa/searchlib/fef/test/indexenvironment.cpp @@ -54,4 +54,21 @@ IndexEnvironment::addConstantValue(const vespalib::string &name, (void) insertRes; } +std::optional<vespalib::string> +IndexEnvironment::getOnnxModelFullPath(const vespalib::string &name) const +{ + auto pos = _models.find(name); + if (pos != _models.end()) { + return pos->second; + } + return std::nullopt; +} + +void +IndexEnvironment::addOnnxModel(const vespalib::string &name, const vespalib::string &path) +{ + _models[name] = path; +} + + } diff --git a/searchlib/src/vespa/searchlib/fef/test/indexenvironment.h b/searchlib/src/vespa/searchlib/fef/test/indexenvironment.h index d84cebc7f52..6602d9f8ee9 100644 --- a/searchlib/src/vespa/searchlib/fef/test/indexenvironment.h +++ b/searchlib/src/vespa/searchlib/fef/test/indexenvironment.h @@ -47,6 +47,7 @@ public: }; using ConstantsMap = std::map<vespalib::string, Constant>; + using ModelMap = std::map<vespalib::string, vespalib::string>; IndexEnvironment(); ~IndexEnvironment(); @@ -83,6 +84,9 @@ public: vespalib::eval::ValueType type, std::unique_ptr<vespalib::eval::Value> value); + std::optional<vespalib::string> getOnnxModelFullPath(const vespalib::string &name) const override; + void addOnnxModel(const vespalib::string &name, const vespalib::string &path); + private: IndexEnvironment(const IndexEnvironment &); // hide IndexEnvironment & operator=(const IndexEnvironment &); // hide @@ -93,6 +97,7 @@ private: AttributeMap _attrMap; TableManager _tableMan; ConstantsMap _constants; + ModelMap _models; }; } diff --git a/streamingvisitors/src/vespa/searchvisitor/indexenvironment.h b/streamingvisitors/src/vespa/searchvisitor/indexenvironment.h index ac6836b08c5..3bbfb0b23f9 100644 --- a/streamingvisitors/src/vespa/searchvisitor/indexenvironment.h +++ b/streamingvisitors/src/vespa/searchvisitor/indexenvironment.h @@ -73,6 +73,10 @@ public: return vespalib::eval::ConstantValue::UP(); } + std::optional<vespalib::string> getOnnxModelFullPath(const vespalib::string &) const override { + return std::nullopt; + } + bool addField(const vespalib::string & name, bool isAttribute); search::fef::Properties & getProperties() { return _properties; } |