diff options
Diffstat (limited to 'config-model/src')
6 files changed, 38 insertions, 5 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java index 1cc33664e8c..60733a4f5ba 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java @@ -9,7 +9,7 @@ import java.util.HashMap; import java.util.Map; /** - * ONNX models tied to a search definition. + * ONNX models tied to a search definition or global. * * @author lesters */ @@ -23,6 +23,10 @@ public class OnnxModels { models.put(name, model); } + public void add(Map<String, OnnxModel> models) { + models.values().forEach(this::add); + } + public OnnxModel get(String name) { return models.get(name); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index b460752d7bd..be246a143b2 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -161,7 +161,7 @@ public class RankProfile implements Cloneable { return search != null ? search.rankingConstants() : model.rankingConstants(); } - private Map<String, OnnxModel> onnxModels() { + public Map<String, OnnxModel> onnxModels() { return search != null ? search.onnxModels().asMap() : Collections.emptyMap(); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java index 22a32c8fd65..42fa1df802b 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java @@ -57,8 +57,8 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ ModelContext.Properties deployProperties) { setName(search == null ? "default" : search.getName()); this.rankingConstants = rankingConstants; - deriveRankProfiles(rankProfileRegistry, queryProfiles, importedModels, search, attributeFields, deployProperties); this.onnxModels = search == null ? new OnnxModels() : search.onnxModels(); // as ONNX models come from parsing rank expressions + deriveRankProfiles(rankProfileRegistry, queryProfiles, importedModels, search, attributeFields, deployProperties); } private void deriveRankProfiles(RankProfileRegistry rankProfileRegistry, @@ -75,6 +75,9 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ for (RankProfile rank : rankProfileRegistry.rankProfilesOf(search)) { if (search != null && "default".equals(rank.getName())) continue; + if (search == null) { + this.onnxModels.add(rank.onnxModels()); + } RawRankProfile rawRank = new RawRankProfile(rank, queryProfiles, importedModels, attributeFields, deployProperties); rankProfiles.put(rawRank.getName(), rawRank); @@ -94,6 +97,10 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ rankingConstants.sendTo(services); } + public void sendOnnxModelsTo(Collection<? extends AbstractService> services) { + onnxModels.sendTo(services); + } + @Override public String getDerivedName() { return "rank-profiles"; } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java index f0c62664988..4e78f44d0fe 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java @@ -23,6 +23,7 @@ import com.yahoo.jdisc.http.ServletPathsConfig; import com.yahoo.osgi.provider.model.ComponentModel; import com.yahoo.search.config.QrStartConfig; import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.OnnxModelsConfig; import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import com.yahoo.vespa.model.admin.metricsproxy.MetricsProxyContainer; import com.yahoo.vespa.model.container.component.BindingPattern; @@ -56,6 +57,7 @@ public final class ApplicationContainerCluster extends ContainerCluster<Applicat QrStartConfig.Producer, RankProfilesConfig.Producer, RankingConstantsConfig.Producer, + OnnxModelsConfig.Producer, ServletPathsConfig.Producer, ContainerMbusConfig.Producer, MetricsProxyApiConfig.Producer, @@ -227,6 +229,11 @@ public final class ApplicationContainerCluster extends ContainerCluster<Applicat } @Override + public void getConfig(OnnxModelsConfig.Builder builder) { + if (modelEvaluation != null) modelEvaluation.getConfig(builder); + } + + @Override public void getConfig(ContainerMbusConfig.Builder builder) { if (mbusParams != null) { if (mbusParams.maxConcurrentFactor != null) diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java index 72f1921e6a2..510d2fe3d99 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java @@ -5,6 +5,7 @@ import ai.vespa.models.evaluation.ModelsEvaluator; import com.yahoo.osgi.provider.model.ComponentModel; import com.yahoo.searchdefinition.derived.RankProfileList; import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.OnnxModelsConfig; import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import com.yahoo.vespa.model.container.component.Handler; import com.yahoo.vespa.model.container.component.SystemBindingPattern; @@ -17,7 +18,10 @@ import java.util.Objects; * * @author bratseth */ -public class ContainerModelEvaluation implements RankProfilesConfig.Producer, RankingConstantsConfig.Producer { +public class ContainerModelEvaluation implements RankProfilesConfig.Producer, + RankingConstantsConfig.Producer, + OnnxModelsConfig.Producer +{ private final static String BUNDLE_NAME = "model-evaluation"; private final static String EVALUATOR_NAME = ModelsEvaluator.class.getName(); @@ -35,6 +39,7 @@ public class ContainerModelEvaluation implements RankProfilesConfig.Producer, Ra public void prepare(List<ApplicationContainer> containers) { rankProfileList.sendConstantsTo(containers); + rankProfileList.sendOnnxModelsTo(containers); } @Override @@ -47,6 +52,11 @@ public class ContainerModelEvaluation implements RankProfilesConfig.Producer, Ra rankProfileList.getConfig(builder); } + @Override + public void getConfig(OnnxModelsConfig.Builder builder) { + rankProfileList.getConfig(builder); + } + public static Handler<?> getHandler() { Handler<?> handler = new Handler<>(new ComponentModel(REST_HANDLER_NAME, null, BUNDLE_NAME)); handler.addServerBindings( diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java index fc6a4ee2783..d0196ace766 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java @@ -15,6 +15,7 @@ import com.yahoo.path.Path; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.OnnxModelsConfig; import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import com.yahoo.vespa.model.VespaModel; import com.yahoo.vespa.model.container.ApplicationContainerCluster; @@ -95,6 +96,10 @@ public class ModelEvaluationTest { cluster.getConfig(cb); RankingConstantsConfig constantsConfig = new RankingConstantsConfig(cb); + OnnxModelsConfig.Builder ob = new OnnxModelsConfig.Builder(); + cluster.getConfig(ob); + OnnxModelsConfig onnxModelsConfig = new OnnxModelsConfig(ob); + assertEquals(4, config.rankprofile().size()); Set<String> modelNames = config.rankprofile().stream().map(v -> v.name()).collect(Collectors.toSet()); assertTrue(modelNames.contains("xgboost_2_2")); @@ -109,7 +114,7 @@ public class ModelEvaluationTest { assertEquals(profile, sb.toString()); ModelsEvaluator evaluator = new ModelsEvaluator(new ToleratingMissingConstantFilesRankProfilesConfigImporter(MockFileAcquirer.returnFile(null)) - .importFrom(config, constantsConfig)); + .importFrom(config, constantsConfig, onnxModelsConfig)); assertEquals(4, evaluator.models().size()); |