diff options
Diffstat (limited to 'searchcore/src/apps/verify_ranksetup/verify_ranksetup.cpp')
-rw-r--r-- | searchcore/src/apps/verify_ranksetup/verify_ranksetup.cpp | 35 |
1 files changed, 28 insertions, 7 deletions
diff --git a/searchcore/src/apps/verify_ranksetup/verify_ranksetup.cpp b/searchcore/src/apps/verify_ranksetup/verify_ranksetup.cpp index 118cad4d8ef..7043c450047 100644 --- a/searchcore/src/apps/verify_ranksetup/verify_ranksetup.cpp +++ b/searchcore/src/apps/verify_ranksetup/verify_ranksetup.cpp @@ -12,6 +12,7 @@ #include <vespa/eval/tensor/default_tensor_engine.h> #include <vespa/searchcommon/common/schemaconfigurer.h> #include <vespa/searchcore/config/config-ranking-constants.h> +#include <vespa/searchcore/config/config-onnx-models.h> #include <vespa/searchcore/proton/matching/indexenvironment.h> #include <vespa/searchlib/features/setup.h> #include <vespa/searchlib/fef/fef.h> @@ -28,10 +29,12 @@ using config::ConfigSubscriber; using config::IConfigContext; using config::InvalidConfigException; using proton::matching::IConstantValueRepo; +using proton::matching::OnnxModels; using vespa::config::search::AttributesConfig; using vespa::config::search::IndexschemaConfig; using vespa::config::search::RankProfilesConfig; using vespa::config::search::core::RankingConstantsConfig; +using vespa::config::search::core::OnnxModelsConfig; using vespalib::eval::ConstantValue; using vespalib::eval::TensorSpec; using vespalib::eval::ValueType; @@ -39,17 +42,30 @@ using vespalib::tensor::DefaultTensorEngine; using vespalib::eval::SimpleConstantValue; using vespalib::eval::BadConstantValue; +OnnxModels make_models(const OnnxModelsConfig &modelsCfg) { + OnnxModels::Vector model_list; + for (const auto &entry: modelsCfg.model) { + // TODO(havardpe): resolve model path + vespalib::string model_path = entry.name; + model_path += ".onnx"; + model_list.emplace_back(entry.name, model_path); + } + return OnnxModels(model_list); +} + class App : public FastOS_Application { public: bool verify(const search::index::Schema &schema, const search::fef::Properties &props, - const IConstantValueRepo &repo); + const IConstantValueRepo &repo, + OnnxModels models); bool verifyConfig(const RankProfilesConfig &rankCfg, const IndexschemaConfig &schemaCfg, const AttributesConfig &attributeCfg, - const RankingConstantsConfig &constantsCfg); + const RankingConstantsConfig &constantsCfg, + const OnnxModelsConfig &modelsCfg); int usage(); int Main() override; @@ -77,9 +93,10 @@ struct DummyConstantValueRepo : IConstantValueRepo { bool App::verify(const search::index::Schema &schema, const search::fef::Properties &props, - const IConstantValueRepo &repo) + const IConstantValueRepo &repo, + OnnxModels models) { - proton::matching::IndexEnvironment indexEnv(0, schema, props, repo); + proton::matching::IndexEnvironment indexEnv(0, schema, props, repo, models); search::fef::BlueprintFactory factory; search::features::setup_search_features(factory); search::fef::test::setup_fef_test_plugin(factory); @@ -106,13 +123,15 @@ bool App::verifyConfig(const RankProfilesConfig &rankCfg, const IndexschemaConfig &schemaCfg, const AttributesConfig &attributeCfg, - const RankingConstantsConfig &constantsCfg) + const RankingConstantsConfig &constantsCfg, + const OnnxModelsConfig &modelsCfg) { bool ok = true; search::index::Schema schema; search::index::SchemaBuilder::build(schemaCfg, schema); search::index::SchemaBuilder::build(attributeCfg, schema); DummyConstantValueRepo repo(constantsCfg); + auto models = make_models(modelsCfg); for(size_t i = 0; i < rankCfg.rankprofile.size(); i++) { search::fef::Properties properties; const RankProfilesConfig::Rankprofile &profile = rankCfg.rankprofile[i]; @@ -120,7 +139,7 @@ App::verifyConfig(const RankProfilesConfig &rankCfg, properties.add(profile.fef.property[j].name, profile.fef.property[j].value); } - if (verify(schema, properties, repo)) { + if (verify(schema, properties, repo, models)) { LOG(info, "rank profile '%s': pass", profile.name.c_str()); } else { LOG(error, "rank profile '%s': FAIL", profile.name.c_str()); @@ -157,12 +176,14 @@ App::Main() ConfigHandle<AttributesConfig>::UP attributesHandle = subscriber.subscribe<AttributesConfig>(cfgId); ConfigHandle<IndexschemaConfig>::UP schemaHandle = subscriber.subscribe<IndexschemaConfig>(cfgId); ConfigHandle<RankingConstantsConfig>::UP constantsHandle = subscriber.subscribe<RankingConstantsConfig>(cfgId); + ConfigHandle<OnnxModelsConfig>::UP modelsHandle = subscriber.subscribe<OnnxModelsConfig>(cfgId); subscriber.nextConfig(); ok = verifyConfig(*rankHandle->getConfig(), *schemaHandle->getConfig(), *attributesHandle->getConfig(), - *constantsHandle->getConfig()); + *constantsHandle->getConfig(), + *modelsHandle->getConfig()); } catch (ConfigRuntimeException & e) { LOG(error, "Unable to subscribe to config: %s", e.getMessage().c_str()); } catch (InvalidConfigException & e) { |