summaryrefslogtreecommitdiffstats
path: root/searchcore/src/apps/verify_ranksetup/verify_ranksetup.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'searchcore/src/apps/verify_ranksetup/verify_ranksetup.cpp')
-rw-r--r--searchcore/src/apps/verify_ranksetup/verify_ranksetup.cpp35
1 files changed, 28 insertions, 7 deletions
diff --git a/searchcore/src/apps/verify_ranksetup/verify_ranksetup.cpp b/searchcore/src/apps/verify_ranksetup/verify_ranksetup.cpp
index 118cad4d8ef..7043c450047 100644
--- a/searchcore/src/apps/verify_ranksetup/verify_ranksetup.cpp
+++ b/searchcore/src/apps/verify_ranksetup/verify_ranksetup.cpp
@@ -12,6 +12,7 @@
#include <vespa/eval/tensor/default_tensor_engine.h>
#include <vespa/searchcommon/common/schemaconfigurer.h>
#include <vespa/searchcore/config/config-ranking-constants.h>
+#include <vespa/searchcore/config/config-onnx-models.h>
#include <vespa/searchcore/proton/matching/indexenvironment.h>
#include <vespa/searchlib/features/setup.h>
#include <vespa/searchlib/fef/fef.h>
@@ -28,10 +29,12 @@ using config::ConfigSubscriber;
using config::IConfigContext;
using config::InvalidConfigException;
using proton::matching::IConstantValueRepo;
+using proton::matching::OnnxModels;
using vespa::config::search::AttributesConfig;
using vespa::config::search::IndexschemaConfig;
using vespa::config::search::RankProfilesConfig;
using vespa::config::search::core::RankingConstantsConfig;
+using vespa::config::search::core::OnnxModelsConfig;
using vespalib::eval::ConstantValue;
using vespalib::eval::TensorSpec;
using vespalib::eval::ValueType;
@@ -39,17 +42,30 @@ using vespalib::tensor::DefaultTensorEngine;
using vespalib::eval::SimpleConstantValue;
using vespalib::eval::BadConstantValue;
+OnnxModels make_models(const OnnxModelsConfig &modelsCfg) {
+ OnnxModels::Vector model_list;
+ for (const auto &entry: modelsCfg.model) {
+ // TODO(havardpe): resolve model path
+ vespalib::string model_path = entry.name;
+ model_path += ".onnx";
+ model_list.emplace_back(entry.name, model_path);
+ }
+ return OnnxModels(model_list);
+}
+
class App : public FastOS_Application
{
public:
bool verify(const search::index::Schema &schema,
const search::fef::Properties &props,
- const IConstantValueRepo &repo);
+ const IConstantValueRepo &repo,
+ OnnxModels models);
bool verifyConfig(const RankProfilesConfig &rankCfg,
const IndexschemaConfig &schemaCfg,
const AttributesConfig &attributeCfg,
- const RankingConstantsConfig &constantsCfg);
+ const RankingConstantsConfig &constantsCfg,
+ const OnnxModelsConfig &modelsCfg);
int usage();
int Main() override;
@@ -77,9 +93,10 @@ struct DummyConstantValueRepo : IConstantValueRepo {
bool
App::verify(const search::index::Schema &schema,
const search::fef::Properties &props,
- const IConstantValueRepo &repo)
+ const IConstantValueRepo &repo,
+ OnnxModels models)
{
- proton::matching::IndexEnvironment indexEnv(0, schema, props, repo);
+ proton::matching::IndexEnvironment indexEnv(0, schema, props, repo, models);
search::fef::BlueprintFactory factory;
search::features::setup_search_features(factory);
search::fef::test::setup_fef_test_plugin(factory);
@@ -106,13 +123,15 @@ bool
App::verifyConfig(const RankProfilesConfig &rankCfg,
const IndexschemaConfig &schemaCfg,
const AttributesConfig &attributeCfg,
- const RankingConstantsConfig &constantsCfg)
+ const RankingConstantsConfig &constantsCfg,
+ const OnnxModelsConfig &modelsCfg)
{
bool ok = true;
search::index::Schema schema;
search::index::SchemaBuilder::build(schemaCfg, schema);
search::index::SchemaBuilder::build(attributeCfg, schema);
DummyConstantValueRepo repo(constantsCfg);
+ auto models = make_models(modelsCfg);
for(size_t i = 0; i < rankCfg.rankprofile.size(); i++) {
search::fef::Properties properties;
const RankProfilesConfig::Rankprofile &profile = rankCfg.rankprofile[i];
@@ -120,7 +139,7 @@ App::verifyConfig(const RankProfilesConfig &rankCfg,
properties.add(profile.fef.property[j].name,
profile.fef.property[j].value);
}
- if (verify(schema, properties, repo)) {
+ if (verify(schema, properties, repo, models)) {
LOG(info, "rank profile '%s': pass", profile.name.c_str());
} else {
LOG(error, "rank profile '%s': FAIL", profile.name.c_str());
@@ -157,12 +176,14 @@ App::Main()
ConfigHandle<AttributesConfig>::UP attributesHandle = subscriber.subscribe<AttributesConfig>(cfgId);
ConfigHandle<IndexschemaConfig>::UP schemaHandle = subscriber.subscribe<IndexschemaConfig>(cfgId);
ConfigHandle<RankingConstantsConfig>::UP constantsHandle = subscriber.subscribe<RankingConstantsConfig>(cfgId);
+ ConfigHandle<OnnxModelsConfig>::UP modelsHandle = subscriber.subscribe<OnnxModelsConfig>(cfgId);
subscriber.nextConfig();
ok = verifyConfig(*rankHandle->getConfig(),
*schemaHandle->getConfig(),
*attributesHandle->getConfig(),
- *constantsHandle->getConfig());
+ *constantsHandle->getConfig(),
+ *modelsHandle->getConfig());
} catch (ConfigRuntimeException & e) {
LOG(error, "Unable to subscribe to config: %s", e.getMessage().c_str());
} catch (InvalidConfigException & e) {