aboutsummaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-05-20 11:29:12 +0200
committerLester Solbakken <lesters@oath.com>2021-05-20 11:29:12 +0200
commit4a126bdd16323226411561b969e581af90260692 (patch)
tree50bd5318dd8e0f174ff26a41a786042b787c9001 /config-model
parentfc0711f7870b55ea77d18d87ec3e70b75e0de2e0 (diff)
Evaluate ONNX models in model-evaluation with ONNX RT
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java6
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java2
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java9
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java7
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java12
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java7
6 files changed, 38 insertions, 5 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java
index 1cc33664e8c..60733a4f5ba 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java
@@ -9,7 +9,7 @@ import java.util.HashMap;
import java.util.Map;
/**
- * ONNX models tied to a search definition.
+ * ONNX models tied to a search definition or global.
*
* @author lesters
*/
@@ -23,6 +23,10 @@ public class OnnxModels {
models.put(name, model);
}
+ public void add(Map<String, OnnxModel> models) {
+ models.values().forEach(this::add);
+ }
+
public OnnxModel get(String name) {
return models.get(name);
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
index b460752d7bd..be246a143b2 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
@@ -161,7 +161,7 @@ public class RankProfile implements Cloneable {
return search != null ? search.rankingConstants() : model.rankingConstants();
}
- private Map<String, OnnxModel> onnxModels() {
+ public Map<String, OnnxModel> onnxModels() {
return search != null ? search.onnxModels().asMap() : Collections.emptyMap();
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
index 22a32c8fd65..42fa1df802b 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
@@ -57,8 +57,8 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
ModelContext.Properties deployProperties) {
setName(search == null ? "default" : search.getName());
this.rankingConstants = rankingConstants;
- deriveRankProfiles(rankProfileRegistry, queryProfiles, importedModels, search, attributeFields, deployProperties);
this.onnxModels = search == null ? new OnnxModels() : search.onnxModels(); // as ONNX models come from parsing rank expressions
+ deriveRankProfiles(rankProfileRegistry, queryProfiles, importedModels, search, attributeFields, deployProperties);
}
private void deriveRankProfiles(RankProfileRegistry rankProfileRegistry,
@@ -75,6 +75,9 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
for (RankProfile rank : rankProfileRegistry.rankProfilesOf(search)) {
if (search != null && "default".equals(rank.getName())) continue;
+ if (search == null) {
+ this.onnxModels.add(rank.onnxModels());
+ }
RawRankProfile rawRank = new RawRankProfile(rank, queryProfiles, importedModels, attributeFields, deployProperties);
rankProfiles.put(rawRank.getName(), rawRank);
@@ -94,6 +97,10 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
rankingConstants.sendTo(services);
}
+ public void sendOnnxModelsTo(Collection<? extends AbstractService> services) {
+ onnxModels.sendTo(services);
+ }
+
@Override
public String getDerivedName() { return "rank-profiles"; }
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java
index f0c62664988..4e78f44d0fe 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java
@@ -23,6 +23,7 @@ import com.yahoo.jdisc.http.ServletPathsConfig;
import com.yahoo.osgi.provider.model.ComponentModel;
import com.yahoo.search.config.QrStartConfig;
import com.yahoo.vespa.config.search.RankProfilesConfig;
+import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import com.yahoo.vespa.model.admin.metricsproxy.MetricsProxyContainer;
import com.yahoo.vespa.model.container.component.BindingPattern;
@@ -56,6 +57,7 @@ public final class ApplicationContainerCluster extends ContainerCluster<Applicat
QrStartConfig.Producer,
RankProfilesConfig.Producer,
RankingConstantsConfig.Producer,
+ OnnxModelsConfig.Producer,
ServletPathsConfig.Producer,
ContainerMbusConfig.Producer,
MetricsProxyApiConfig.Producer,
@@ -227,6 +229,11 @@ public final class ApplicationContainerCluster extends ContainerCluster<Applicat
}
@Override
+ public void getConfig(OnnxModelsConfig.Builder builder) {
+ if (modelEvaluation != null) modelEvaluation.getConfig(builder);
+ }
+
+ @Override
public void getConfig(ContainerMbusConfig.Builder builder) {
if (mbusParams != null) {
if (mbusParams.maxConcurrentFactor != null)
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java
index 72f1921e6a2..510d2fe3d99 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java
@@ -5,6 +5,7 @@ import ai.vespa.models.evaluation.ModelsEvaluator;
import com.yahoo.osgi.provider.model.ComponentModel;
import com.yahoo.searchdefinition.derived.RankProfileList;
import com.yahoo.vespa.config.search.RankProfilesConfig;
+import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import com.yahoo.vespa.model.container.component.Handler;
import com.yahoo.vespa.model.container.component.SystemBindingPattern;
@@ -17,7 +18,10 @@ import java.util.Objects;
*
* @author bratseth
*/
-public class ContainerModelEvaluation implements RankProfilesConfig.Producer, RankingConstantsConfig.Producer {
+public class ContainerModelEvaluation implements RankProfilesConfig.Producer,
+ RankingConstantsConfig.Producer,
+ OnnxModelsConfig.Producer
+{
private final static String BUNDLE_NAME = "model-evaluation";
private final static String EVALUATOR_NAME = ModelsEvaluator.class.getName();
@@ -35,6 +39,7 @@ public class ContainerModelEvaluation implements RankProfilesConfig.Producer, Ra
public void prepare(List<ApplicationContainer> containers) {
rankProfileList.sendConstantsTo(containers);
+ rankProfileList.sendOnnxModelsTo(containers);
}
@Override
@@ -47,6 +52,11 @@ public class ContainerModelEvaluation implements RankProfilesConfig.Producer, Ra
rankProfileList.getConfig(builder);
}
+ @Override
+ public void getConfig(OnnxModelsConfig.Builder builder) {
+ rankProfileList.getConfig(builder);
+ }
+
public static Handler<?> getHandler() {
Handler<?> handler = new Handler<>(new ComponentModel(REST_HANDLER_NAME, null, BUNDLE_NAME));
handler.addServerBindings(
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
index fc6a4ee2783..d0196ace766 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
@@ -15,6 +15,7 @@ import com.yahoo.path.Path;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.config.search.RankProfilesConfig;
+import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import com.yahoo.vespa.model.VespaModel;
import com.yahoo.vespa.model.container.ApplicationContainerCluster;
@@ -95,6 +96,10 @@ public class ModelEvaluationTest {
cluster.getConfig(cb);
RankingConstantsConfig constantsConfig = new RankingConstantsConfig(cb);
+ OnnxModelsConfig.Builder ob = new OnnxModelsConfig.Builder();
+ cluster.getConfig(ob);
+ OnnxModelsConfig onnxModelsConfig = new OnnxModelsConfig(ob);
+
assertEquals(4, config.rankprofile().size());
Set<String> modelNames = config.rankprofile().stream().map(v -> v.name()).collect(Collectors.toSet());
assertTrue(modelNames.contains("xgboost_2_2"));
@@ -109,7 +114,7 @@ public class ModelEvaluationTest {
assertEquals(profile, sb.toString());
ModelsEvaluator evaluator = new ModelsEvaluator(new ToleratingMissingConstantFilesRankProfilesConfigImporter(MockFileAcquirer.returnFile(null))
- .importFrom(config, constantsConfig));
+ .importFrom(config, constantsConfig, onnxModelsConfig));
assertEquals(4, evaluator.models().size());