From 4a126bdd16323226411561b969e581af90260692 Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Thu, 20 May 2021 11:29:12 +0200 Subject: Evaluate ONNX models in model-evaluation with ONNX RT --- .../src/main/java/com/yahoo/searchdefinition/OnnxModels.java | 6 +++++- .../main/java/com/yahoo/searchdefinition/RankProfile.java | 2 +- .../com/yahoo/searchdefinition/derived/RankProfileList.java | 9 ++++++++- .../vespa/model/container/ApplicationContainerCluster.java | 7 +++++++ .../vespa/model/container/ContainerModelEvaluation.java | 12 +++++++++++- .../java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java | 7 ++++++- 6 files changed, 38 insertions(+), 5 deletions(-) (limited to 'config-model/src') diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java index 1cc33664e8c..60733a4f5ba 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java @@ -9,7 +9,7 @@ import java.util.HashMap; import java.util.Map; /** - * ONNX models tied to a search definition. + * ONNX models tied to a search definition or global. * * @author lesters */ @@ -23,6 +23,10 @@ public class OnnxModels { models.put(name, model); } + public void add(Map models) { + models.values().forEach(this::add); + } + public OnnxModel get(String name) { return models.get(name); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index b460752d7bd..be246a143b2 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -161,7 +161,7 @@ public class RankProfile implements Cloneable { return search != null ? search.rankingConstants() : model.rankingConstants(); } - private Map onnxModels() { + public Map onnxModels() { return search != null ? search.onnxModels().asMap() : Collections.emptyMap(); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java index 22a32c8fd65..42fa1df802b 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java @@ -57,8 +57,8 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ ModelContext.Properties deployProperties) { setName(search == null ? "default" : search.getName()); this.rankingConstants = rankingConstants; - deriveRankProfiles(rankProfileRegistry, queryProfiles, importedModels, search, attributeFields, deployProperties); this.onnxModels = search == null ? new OnnxModels() : search.onnxModels(); // as ONNX models come from parsing rank expressions + deriveRankProfiles(rankProfileRegistry, queryProfiles, importedModels, search, attributeFields, deployProperties); } private void deriveRankProfiles(RankProfileRegistry rankProfileRegistry, @@ -75,6 +75,9 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ for (RankProfile rank : rankProfileRegistry.rankProfilesOf(search)) { if (search != null && "default".equals(rank.getName())) continue; + if (search == null) { + this.onnxModels.add(rank.onnxModels()); + } RawRankProfile rawRank = new RawRankProfile(rank, queryProfiles, importedModels, attributeFields, deployProperties); rankProfiles.put(rawRank.getName(), rawRank); @@ -94,6 +97,10 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ rankingConstants.sendTo(services); } + public void sendOnnxModelsTo(Collection services) { + onnxModels.sendTo(services); + } + @Override public String getDerivedName() { return "rank-profiles"; } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java index f0c62664988..4e78f44d0fe 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java @@ -23,6 +23,7 @@ import com.yahoo.jdisc.http.ServletPathsConfig; import com.yahoo.osgi.provider.model.ComponentModel; import com.yahoo.search.config.QrStartConfig; import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.OnnxModelsConfig; import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import com.yahoo.vespa.model.admin.metricsproxy.MetricsProxyContainer; import com.yahoo.vespa.model.container.component.BindingPattern; @@ -56,6 +57,7 @@ public final class ApplicationContainerCluster extends ContainerCluster containers) { rankProfileList.sendConstantsTo(containers); + rankProfileList.sendOnnxModelsTo(containers); } @Override @@ -47,6 +52,11 @@ public class ContainerModelEvaluation implements RankProfilesConfig.Producer, Ra rankProfileList.getConfig(builder); } + @Override + public void getConfig(OnnxModelsConfig.Builder builder) { + rankProfileList.getConfig(builder); + } + public static Handler getHandler() { Handler handler = new Handler<>(new ComponentModel(REST_HANDLER_NAME, null, BUNDLE_NAME)); handler.addServerBindings( diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java index fc6a4ee2783..d0196ace766 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java @@ -15,6 +15,7 @@ import com.yahoo.path.Path; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.OnnxModelsConfig; import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import com.yahoo.vespa.model.VespaModel; import com.yahoo.vespa.model.container.ApplicationContainerCluster; @@ -95,6 +96,10 @@ public class ModelEvaluationTest { cluster.getConfig(cb); RankingConstantsConfig constantsConfig = new RankingConstantsConfig(cb); + OnnxModelsConfig.Builder ob = new OnnxModelsConfig.Builder(); + cluster.getConfig(ob); + OnnxModelsConfig onnxModelsConfig = new OnnxModelsConfig(ob); + assertEquals(4, config.rankprofile().size()); Set modelNames = config.rankprofile().stream().map(v -> v.name()).collect(Collectors.toSet()); assertTrue(modelNames.contains("xgboost_2_2")); @@ -109,7 +114,7 @@ public class ModelEvaluationTest { assertEquals(profile, sb.toString()); ModelsEvaluator evaluator = new ModelsEvaluator(new ToleratingMissingConstantFilesRankProfilesConfigImporter(MockFileAcquirer.returnFile(null)) - .importFrom(config, constantsConfig)); + .importFrom(config, constantsConfig, onnxModelsConfig)); assertEquals(4, evaluator.models().size()); -- cgit v1.2.3