diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/derived')
5 files changed, 138 insertions, 92 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java index 50029613b2e..c7b9b94a4b2 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java @@ -82,8 +82,7 @@ public class DerivedConfiguration implements AttributesConfig.Producer { summaries = new Summaries(schema, deployState.getDeployLogger(), deployState.getProperties().featureFlags()); summaryMap = new SummaryMap(schema); juniperrc = new Juniperrc(schema); - rankProfileList = new RankProfileList(schema, schema.constants(), schema.rankExpressionFiles(), - schema.onnxModels(), attributeFields, deployState); + rankProfileList = new RankProfileList(schema, schema.rankExpressionFiles(), attributeFields, deployState); indexingScript = new IndexingScript(schema); indexInfo = new IndexInfo(schema); schemaInfo = new SchemaInfo(schema, deployState.rankProfileRegistry(), summaries, summaryMap); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/FileDistributedConstants.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/FileDistributedConstants.java index 8de86beacdb..433bfb108d6 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/FileDistributedConstants.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/FileDistributedConstants.java @@ -5,12 +5,12 @@ import com.yahoo.config.application.api.FileRegistry; import com.yahoo.searchdefinition.DistributableResource; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.tensor.TensorType; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import java.util.Collection; import java.util.Collections; import java.util.LinkedHashMap; import java.util.Map; -import java.util.function.Function; /** * Constant values for ranking/model execution tied to a rank profile, @@ -41,6 +41,15 @@ public class FileDistributedConstants { /** Returns a read-only map of the constants in this indexed by name. */ public Map<String, DistributableConstant> asMap() { return constants; } + public void getConfig(RankingConstantsConfig.Builder builder) { + for (var constant : constants.values()) { + builder.constant(new RankingConstantsConfig.Constant.Builder() + .name(constant.getName()) + .fileref(constant.getFileReference()) + .type(constant.getType())); + } + } + public static class DistributableConstant extends DistributableResource { private final TensorType tensorType; diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/FileDistributedOnnxModels.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/FileDistributedOnnxModels.java new file mode 100644 index 00000000000..a1310d6c18e --- /dev/null +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/FileDistributedOnnxModels.java @@ -0,0 +1,63 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchdefinition.derived; + +import com.yahoo.config.application.api.FileRegistry; +import com.yahoo.searchdefinition.OnnxModel; +import com.yahoo.searchdefinition.Schema; +import com.yahoo.vespa.config.search.core.OnnxModelsConfig; + +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Optional; +import java.util.logging.Logger; + +/** + * ONNX models distributed as files. + * + * @author bratseth + */ +public class FileDistributedOnnxModels { + + private static final Logger log = Logger.getLogger(FileDistributedOnnxModels.class.getName()); + + private final Map<String, OnnxModel> models; + + public FileDistributedOnnxModels(FileRegistry fileRegistry, Collection<OnnxModel> models) { + Map<String, OnnxModel> distributableModels = new LinkedHashMap<>(); + for (var model : models) { + model.validate(); + model.register(fileRegistry); + distributableModels.put(model.getName(), model); + } + this.models = Collections.unmodifiableMap(distributableModels); + } + + public Map<String, OnnxModel> asMap() { return models; } + + public void getConfig(OnnxModelsConfig.Builder builder) { + for (OnnxModel model : models.values()) { + if ("".equals(model.getFileReference())) + log.warning("Illegal file reference " + model); // Let tests pass ... we should find a better way + else { + OnnxModelsConfig.Model.Builder modelBuilder = new OnnxModelsConfig.Model.Builder(); + modelBuilder.dry_run_on_setup(true); + modelBuilder.name(model.getName()); + modelBuilder.fileref(model.getFileReference()); + model.getInputMap().forEach((name, source) -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source))); + model.getOutputMap().forEach((name, as) -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(name).as(as))); + if (model.getStatelessExecutionMode().isPresent()) + modelBuilder.stateless_execution_mode(model.getStatelessExecutionMode().get()); + if (model.getStatelessInterOpThreads().isPresent()) + modelBuilder.stateless_interop_threads(model.getStatelessInterOpThreads().get()); + if (model.getStatelessIntraOpThreads().isPresent()) + modelBuilder.stateless_intraop_threads(model.getStatelessIntraOpThreads().get()); + + builder.model(modelBuilder); + } + } + } + +} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java index d415f7d0d6e..081450275d1 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java @@ -2,13 +2,11 @@ package com.yahoo.searchdefinition.derived; import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels; -import com.yahoo.config.application.api.DeployLogger; import com.yahoo.config.model.api.ModelContext; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.search.query.profile.QueryProfileRegistry; -import com.yahoo.searchdefinition.OnnxModel; -import com.yahoo.searchdefinition.OnnxModels; import com.yahoo.searchdefinition.LargeRankExpressions; +import com.yahoo.searchdefinition.OnnxModel; import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.vespa.config.search.RankProfilesConfig; @@ -24,13 +22,10 @@ import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; -import java.util.logging.Level; -import java.util.logging.Logger; /** * The derived rank profiles of a schema @@ -39,19 +34,17 @@ import java.util.logging.Logger; */ public class RankProfileList extends Derived implements RankProfilesConfig.Producer { - private static final Logger log = Logger.getLogger(RankProfileList.class.getName()); - private final Map<String, RawRankProfile> rankProfiles; private final FileDistributedConstants constants; private final LargeRankExpressions largeRankExpressions; - private final OnnxModels onnxModels; + private final FileDistributedOnnxModels onnxModels; public static final RankProfileList empty = new RankProfileList(); private RankProfileList() { constants = new FileDistributedConstants(null, List.of()); largeRankExpressions = new LargeRankExpressions(null); - onnxModels = new OnnxModels(null, Optional.empty()); + onnxModels = new FileDistributedOnnxModels(null, List.of()); rankProfiles = Map.of(); } @@ -62,46 +55,16 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ * @param attributeFields the attribute fields to create a ranking for */ public RankProfileList(Schema schema, - Map<Reference, RankProfile.Constant> constantsFromSchema, LargeRankExpressions largeRankExpressions, - OnnxModels onnxModels, AttributeFields attributeFields, DeployState deployState) { setName(schema == null ? "default" : schema.getName()); this.largeRankExpressions = largeRankExpressions; - this.onnxModels = onnxModels; // as ONNX models come from parsing rank expressions this.rankProfiles = deriveRankProfiles(schema, attributeFields, deployState); - this.constants = deriveFileDistributedConstants(schema, constantsFromSchema, rankProfiles.values(), deployState); - } - - private static FileDistributedConstants deriveFileDistributedConstants(Schema schema, - Map<Reference, RankProfile.Constant> constantsFromSchema, - Collection<RawRankProfile> rankProfiles, - DeployState deployState) { - Map<Reference, RankProfile.Constant> allFileConstants = new HashMap<>(); - addFileConstants(constantsFromSchema.values(), allFileConstants, schema != null ? schema.toString() : "[global]"); - for (var profile : rankProfiles) - addFileConstants(profile.compiled().constants().values(), allFileConstants, profile.toString()); - return new FileDistributedConstants(deployState.getFileRegistry(), allFileConstants.values()); - } - - private static void addFileConstants(Collection<RankProfile.Constant> source, - Map<Reference, RankProfile.Constant> destination, - String sourceName) { - for (var constant : source) { - if (constant.valuePath().isEmpty()) continue; - var existing = destination.get(constant.name()); - if ( existing != null && ! constant.equals(existing)) { - throw new IllegalArgumentException("Duplicate constants: " + sourceName + " have " + constant + - ", but we already have " + existing + - ": Value reference constants must be unique across all rank profiles/models"); - } - destination.put(constant.name(), constant); - } + this.constants = deriveFileDistributedConstants(schema, rankProfiles.values(), deployState); + this.onnxModels = deriveFileDistributedOnnxModels(schema, rankProfiles.values(), deployState); } - public FileDistributedConstants constants() { return constants; } - private boolean areDependenciesReady(RankProfile rank, RankProfileRegistry registry, Set<String> processedProfiles) { return rank.inheritedNames().isEmpty() || processedProfiles.containsAll(rank.inheritedNames()) || @@ -134,7 +97,6 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ rawRankProfiles.putAll(processRankProfiles(ready, deployState.getQueryProfiles().getRegistry(), deployState.getImportedModels(), - schema, attributeFields, deployState.getProperties(), deployState.getExecutor())); @@ -143,20 +105,15 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ return rawRankProfiles; } - private Map<String, RawRankProfile> processRankProfiles(List<RankProfile> ready, + private Map<String, RawRankProfile> processRankProfiles(List<RankProfile> profiles, QueryProfileRegistry queryProfiles, ImportedMlModels importedModels, - Schema schema, AttributeFields attributeFields, ModelContext.Properties deployProperties, ExecutorService executor) { Map<String, Future<RawRankProfile>> futureRawRankProfiles = new LinkedHashMap<>(); - for (RankProfile rank : ready) { - if (schema == null) { - onnxModels.add(rank.onnxModels()); - } - - futureRawRankProfiles.put(rank.name(), executor.submit(() -> new RawRankProfile(rank, largeRankExpressions, queryProfiles, importedModels, + for (RankProfile profile : profiles) { + futureRawRankProfiles.put(profile.name(), executor.submit(() -> new RawRankProfile(profile, largeRankExpressions, queryProfiles, importedModels, attributeFields, deployProperties))); } try { @@ -171,19 +128,63 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ } } - public OnnxModels getOnnxModels() { - return onnxModels; + private static FileDistributedConstants deriveFileDistributedConstants(Schema schema, + Collection<RawRankProfile> rankProfiles, + DeployState deployState) { + Map<Reference, RankProfile.Constant> allFileConstants = new HashMap<>(); + addFileConstants(schema != null ? schema.constants().values() : List.of(), + allFileConstants, + schema != null ? schema.toString() : "[global]"); + for (var profile : rankProfiles) + addFileConstants(profile.compiled().constants().values(), allFileConstants, profile.toString()); + return new FileDistributedConstants(deployState.getFileRegistry(), allFileConstants.values()); } - public Map<String, RawRankProfile> getRankProfiles() { - return rankProfiles; + private static void addFileConstants(Collection<RankProfile.Constant> source, + Map<Reference, RankProfile.Constant> destination, + String sourceName) { + for (var constant : source) { + if (constant.valuePath().isEmpty()) continue; + var existing = destination.get(constant.name()); + if ( existing != null && ! constant.equals(existing)) { + throw new IllegalArgumentException("Duplicate constants: " + sourceName + " have " + constant + + ", but we already have " + existing + + ": Value reference constants must be unique across all rank profiles/models"); + } + destination.put(constant.name(), constant); + } } - /** Returns the raw rank profile with the given name, or null if it is not present */ - public RawRankProfile getRankProfile(String name) { - return rankProfiles.get(name); + private static FileDistributedOnnxModels deriveFileDistributedOnnxModels(Schema schema, + Collection<RawRankProfile> rankProfiles, + DeployState deployState) { + Map<String, OnnxModel> allModels = new LinkedHashMap<>(); + addOnnxModels(schema != null ? schema.onnxModels().values() : List.of(), + allModels, + schema != null ? schema.toString() : "[global]"); + for (var profile : rankProfiles) + addOnnxModels(profile.compiled().onnxModels().values(), allModels, profile.toString()); + return new FileDistributedOnnxModels(deployState.getFileRegistry(), allModels.values()); } + private static void addOnnxModels(Collection<OnnxModel> source, + Map<String, OnnxModel> destination, + String sourceName) { + for (var model : source) { + var existing = destination.get(model.getName()); + if ( existing != null && ! model.equals(existing)) { + throw new IllegalArgumentException("Duplicate onnx model: " + sourceName + " have " + model + + ", but we already have " + existing + + ": Onnx models must be unique across all rank profiles/models"); + } + destination.put(model.getName(), model); + } + } + + public Map<String, RawRankProfile> getRankProfiles() { return rankProfiles; } + public FileDistributedConstants constants() { return constants; } + public FileDistributedOnnxModels getOnnxModels() { return onnxModels; } + @Override public String getDerivedName() { return "rank-profiles"; } @@ -199,37 +200,11 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ } public void getConfig(RankingConstantsConfig.Builder builder) { - for (var constant : constants.asMap().values()) { - if ("".equals(constant.getFileReference())) - log.warning("Illegal file reference " + constant); // Let tests pass ... we should find a better way - else - builder.constant(new RankingConstantsConfig.Constant.Builder() - .name(constant.getName()) - .fileref(constant.getFileReference()) - .type(constant.getType())); - } + constants.getConfig(builder); } public void getConfig(OnnxModelsConfig.Builder builder) { - for (OnnxModel model : onnxModels.asMap().values()) { - if ("".equals(model.getFileReference())) - log.warning("Illegal file reference " + model); // Let tests pass ... we should find a better way - else { - OnnxModelsConfig.Model.Builder modelBuilder = new OnnxModelsConfig.Model.Builder(); - modelBuilder.dry_run_on_setup(true); - modelBuilder.name(model.getName()); - modelBuilder.fileref(model.getFileReference()); - model.getInputMap().forEach((name, source) -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source))); - model.getOutputMap().forEach((name, as) -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(name).as(as))); - if (model.getStatelessExecutionMode().isPresent()) - modelBuilder.stateless_execution_mode(model.getStatelessExecutionMode().get()); - if (model.getStatelessInterOpThreads().isPresent()) - modelBuilder.stateless_interop_threads(model.getStatelessInterOpThreads().get()); - if (model.getStatelessIntraOpThreads().isPresent()) - modelBuilder.stateless_intraop_threads(model.getStatelessIntraOpThreads().get()); - - builder.model(modelBuilder); - } - } + onnxModels.getConfig(builder); } + } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java index 2189d9d97db..dcd5019cf58 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java @@ -483,7 +483,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { private void deriveOnnxModelFunctionsAndFeatures(RankProfile rankProfile) { if (rankProfile.schema() == null) return; - if (rankProfile.schema().onnxModels().asMap().isEmpty()) return; + if (rankProfile.onnxModels().isEmpty()) return; replaceOnnxFunctionInputs(rankProfile); replaceImplicitOnnxConfigFeatures(summaryFeatures, rankProfile); replaceImplicitOnnxConfigFeatures(matchFeatures, rankProfile); @@ -492,7 +492,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { private void replaceOnnxFunctionInputs(RankProfile rankProfile) { Set<String> functionNames = rankProfile.getFunctions().keySet(); if (functionNames.isEmpty()) return; - for (OnnxModel onnxModel: rankProfile.schema().onnxModels().asMap().values()) { + for (OnnxModel onnxModel: rankProfile.onnxModels().values()) { for (Map.Entry<String, String> mapping : onnxModel.getInputMap().entrySet()) { String source = mapping.getValue(); if (functionNames.contains(source)) { |