diff options
author | Jon Bratseth <bratseth@gmail.com> | 2022-05-16 12:35:32 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2022-05-16 12:35:32 +0200 |
commit | 1d63b5d81c057a8fe99812be22abac38c8195241 (patch) | |
tree | 97bb5db1fb81040c479cc160234948ea66a3100e | |
parent | 640e8893fdb07b6f607d94de5dae24bdf305e705 (diff) |
Add model support for Onnx models in rank profiles
18 files changed, 220 insertions, 207 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/ImmutableSchema.java b/config-model/src/main/java/com/yahoo/searchdefinition/ImmutableSchema.java index 1c643292a05..3719313179f 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/ImmutableSchema.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/ImmutableSchema.java @@ -4,6 +4,7 @@ package com.yahoo.searchdefinition; import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.config.application.api.DeployLogger; import com.yahoo.config.model.api.ModelContext; +import com.yahoo.searchdefinition.derived.FileDistributedOnnxModels; import com.yahoo.searchdefinition.document.ImmutableSDField; import com.yahoo.searchdefinition.document.SDDocumentType; import com.yahoo.searchdefinition.document.SDField; @@ -37,7 +38,7 @@ public interface ImmutableSchema { ModelContext.Properties getDeployProperties(); Map<Reference, RankProfile.Constant> constants(); LargeRankExpressions rankExpressionFiles(); - OnnxModels onnxModels(); + Map<String, OnnxModel> onnxModels(); Stream<ImmutableSDField> allImportedFields(); SDDocumentType getDocument(); ImmutableSDField getField(String name); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java deleted file mode 100644 index c9c12100552..00000000000 --- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchdefinition; - -import com.yahoo.config.application.api.FileRegistry; - -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; -import java.util.Optional; - -/** - * ONNX models tied to a search definition or global. - * - * @author lesters - */ -public class OnnxModels { - - private final FileRegistry fileRegistry; - - /** The schema this belongs to, or empty if it is global */ - private final Optional<Schema> owner; - - private final Map<String, OnnxModel> models = new HashMap<>(); - - public OnnxModels(FileRegistry fileRegistry, Optional<Schema> owner) { - this.fileRegistry = fileRegistry; - this.owner = owner; - } - - public void add(OnnxModel model) { - model.validate(); - model.register(fileRegistry); - String name = model.getName(); - models.put(name, model); - } - - public void add(Map<String, OnnxModel> models) { - models.values().forEach(this::add); - } - - public OnnxModel get(String name) { - var model = models.get(name); - if (model != null) return model; - if (owner.isPresent() && owner.get().inherited().isPresent()) - return owner.get().inherited().get().onnxModels().get(name); - return null; - } - - public boolean has(String name) { - boolean has = models.containsKey(name); - if (has) return true; - if (owner.isPresent() && owner.get().inherited().isPresent()) - return owner.get().inherited().get().onnxModels().has(name); - return false; - } - - public Map<String, OnnxModel> asMap() { - // Shortcuts - if (owner.isEmpty() || owner.get().inherited().isEmpty()) return Collections.unmodifiableMap(models); - if (models.isEmpty()) return owner.get().inherited().get().onnxModels().asMap(); - - var allModels = new HashMap<>(owner.get().inherited().get().onnxModels().asMap()); - allModels.putAll(models); - return Collections.unmodifiableMap(allModels); - } - -} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index fd7e7e57f00..07f3048af04 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -118,6 +118,8 @@ public class RankProfile implements Cloneable { private Map<Reference, Constant> constants = new HashMap<>(); + private Map<String, OnnxModel> onnxModels = new HashMap<>(); + private Set<String> filterFields = new HashSet<>(); private final RankProfileRegistry rankProfileRegistry; @@ -128,8 +130,6 @@ public class RankProfile implements Cloneable { private Boolean strict; - /** Global onnx models not tied to a schema */ - private final OnnxModels onnxModels; private final ApplicationPackage applicationPackage; private final DeployLogger deployLogger; @@ -144,7 +144,6 @@ public class RankProfile implements Cloneable { public RankProfile(String name, Schema schema, RankProfileRegistry rankProfileRegistry) { this.name = Objects.requireNonNull(name, "name cannot be null"); this.schema = Objects.requireNonNull(schema, "schema cannot be null"); - this.onnxModels = null; this.rankProfileRegistry = rankProfileRegistry; this.applicationPackage = schema.applicationPackage(); this.deployLogger = schema.getDeployLogger(); @@ -156,11 +155,10 @@ public class RankProfile implements Cloneable { * @param name the name of the new profile */ public RankProfile(String name, ApplicationPackage applicationPackage, DeployLogger deployLogger, - RankProfileRegistry rankProfileRegistry, OnnxModels onnxModels) { + RankProfileRegistry rankProfileRegistry) { this.name = Objects.requireNonNull(name, "name cannot be null"); this.schema = null; this.rankProfileRegistry = rankProfileRegistry; - this.onnxModels = onnxModels; this.applicationPackage = applicationPackage; this.deployLogger = deployLogger; } @@ -175,10 +173,6 @@ public class RankProfile implements Cloneable { return applicationPackage; } - public Map<String, OnnxModel> onnxModels() { - return schema != null ? schema.onnxModels().asMap() : onnxModels.asMap(); - } - private Stream<ImmutableSDField> allFields() { if (schema == null) return Stream.empty(); if (allFieldsList == null) { @@ -411,6 +405,9 @@ public class RankProfile implements Cloneable { constants.put(constant.name(), constant); } + /** Returns an unmodifiable view of the constants declared in this */ + public Map<Reference, Constant> declaredConstants() { return Collections.unmodifiableMap(constants); } + /** Returns an unmodifiable view of the constants available in this */ public Map<Reference, Constant> constants() { Map<Reference, Constant> allConstants = new HashMap<>(); @@ -430,8 +427,31 @@ public class RankProfile implements Cloneable { return allConstants; } - /** Returns an unmodifiable view of the constants declared in this */ - public Map<Reference, Constant> declaredConstants() { return Collections.unmodifiableMap(constants); } + public void add(OnnxModel model) { + onnxModels.put(model.getName(), model); + } + + /** Returns an unmodifiable map of the onnx models declared in this. */ + public Map<String, OnnxModel> declaredOnnxModels() { return onnxModels; } + + /** Returns an unmodifiable map of the onnx models available in this. */ + public Map<String, OnnxModel> onnxModels() { + Map<String, OnnxModel> allModels = new HashMap<>(); + for (var inheritedProfile : inherited()) { + for (var model : inheritedProfile.onnxModels().values()) { + if (allModels.containsKey(model.getName())) + throw new IllegalArgumentException(model + "' is present in " + + inheritedProfile + " inherited by " + + this + ", but is also present in another profile inherited by it"); + allModels.put(model.getName(), model); + } + } + + if (schema != null) + allModels.putAll(schema.onnxModels()); + allModels.putAll(onnxModels); + return allModels; + } public void addAttributeType(String attributeName, String attributeType) { attributeTypes.addType(attributeName, attributeType); @@ -1064,10 +1084,8 @@ public class RankProfile implements Cloneable { } // Add output types for ONNX models - for (Map.Entry<String, OnnxModel> entry : onnxModels().entrySet()) { - String modelName = entry.getKey(); - OnnxModel model = entry.getValue(); - Arguments args = new Arguments(new ReferenceNode(modelName)); + for (var model : onnxModels().values()) { + Arguments args = new Arguments(new ReferenceNode(model.getName())); Map<String, TensorType> inputTypes = resolveOnnxInputTypes(model, context); TensorType defaultOutputType = model.getTensorType(model.getDefaultOutput(), inputTypes); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/Schema.java b/config-model/src/main/java/com/yahoo/searchdefinition/Schema.java index 147fee05820..8139d46cc0a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/Schema.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/Schema.java @@ -9,6 +9,7 @@ import com.yahoo.config.model.application.provider.BaseDeployLogger; import com.yahoo.config.model.deploy.TestProperties; import com.yahoo.document.DataTypeName; import com.yahoo.document.Field; +import com.yahoo.searchdefinition.derived.FileDistributedOnnxModels; import com.yahoo.searchdefinition.derived.SummaryClass; import com.yahoo.searchdefinition.document.Attribute; import com.yahoo.searchdefinition.document.ImmutableSDField; @@ -27,6 +28,7 @@ import java.io.Reader; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.LinkedHashMap; @@ -90,7 +92,8 @@ public class Schema implements ImmutableSchema { // TODO: Remove on Vespa 9: Should always be in a rank profile private final Map<Reference, RankProfile.Constant> constants = new LinkedHashMap<>(); - private final OnnxModels onnxModels; + // TODO: Remove on Vespa 9: Should always be in a rank profile + private final Map<String, OnnxModel> onnxModels = new LinkedHashMap<>(); /** All imported fields of this (and parent schemas) */ // TODO: Use empty, not optional @@ -150,7 +153,6 @@ public class Schema implements ImmutableSchema { this.properties = properties; this.documentsOnly = documentsOnly; largeRankExpressions = new LargeRankExpressions(fileRegistry); - onnxModels = new OnnxModels(fileRegistry, Optional.of(this)); } /** @@ -230,6 +232,10 @@ public class Schema implements ImmutableSchema { constants.put(constant.name(), constant); } + /** Returns an unmodifiable map of the constants declared in this. */ + public Map<Reference, RankProfile.Constant> declaredConstants() { return constants; } + + /** Returns an unmodifiable map of the constants available in this. */ @Override public Map<Reference, RankProfile.Constant> constants() { if (inherited().isEmpty()) return Collections.unmodifiableMap(constants); @@ -240,8 +246,23 @@ public class Schema implements ImmutableSchema { return allConstants; } + public void add(OnnxModel model) { + onnxModels.put(model.getName(), model); + } + + /** Returns an unmodifiable map of the onnx models declared in this. */ + public Map<String, OnnxModel> declaredOnnxModels() { return onnxModels; } + + /** Returns an unmodifiable map of the onnx models available in this. */ @Override - public OnnxModels onnxModels() { return onnxModels; } + public Map<String, OnnxModel> onnxModels() { + if (inherited().isEmpty()) return Collections.unmodifiableMap(onnxModels); + if (onnxModels.isEmpty()) return inherited().get().onnxModels(); + + Map<String, OnnxModel> allModels = new LinkedHashMap<>(inherited().get().onnxModels()); + allModels.putAll(onnxModels); + return allModels; + } public Optional<TemporaryImportedFields> temporaryImportedFields() { return temporaryImportedFields; diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java index 50029613b2e..c7b9b94a4b2 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java @@ -82,8 +82,7 @@ public class DerivedConfiguration implements AttributesConfig.Producer { summaries = new Summaries(schema, deployState.getDeployLogger(), deployState.getProperties().featureFlags()); summaryMap = new SummaryMap(schema); juniperrc = new Juniperrc(schema); - rankProfileList = new RankProfileList(schema, schema.constants(), schema.rankExpressionFiles(), - schema.onnxModels(), attributeFields, deployState); + rankProfileList = new RankProfileList(schema, schema.rankExpressionFiles(), attributeFields, deployState); indexingScript = new IndexingScript(schema); indexInfo = new IndexInfo(schema); schemaInfo = new SchemaInfo(schema, deployState.rankProfileRegistry(), summaries, summaryMap); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/FileDistributedOnnxModels.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/FileDistributedOnnxModels.java new file mode 100644 index 00000000000..a1310d6c18e --- /dev/null +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/FileDistributedOnnxModels.java @@ -0,0 +1,63 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchdefinition.derived; + +import com.yahoo.config.application.api.FileRegistry; +import com.yahoo.searchdefinition.OnnxModel; +import com.yahoo.searchdefinition.Schema; +import com.yahoo.vespa.config.search.core.OnnxModelsConfig; + +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Optional; +import java.util.logging.Logger; + +/** + * ONNX models distributed as files. + * + * @author bratseth + */ +public class FileDistributedOnnxModels { + + private static final Logger log = Logger.getLogger(FileDistributedOnnxModels.class.getName()); + + private final Map<String, OnnxModel> models; + + public FileDistributedOnnxModels(FileRegistry fileRegistry, Collection<OnnxModel> models) { + Map<String, OnnxModel> distributableModels = new LinkedHashMap<>(); + for (var model : models) { + model.validate(); + model.register(fileRegistry); + distributableModels.put(model.getName(), model); + } + this.models = Collections.unmodifiableMap(distributableModels); + } + + public Map<String, OnnxModel> asMap() { return models; } + + public void getConfig(OnnxModelsConfig.Builder builder) { + for (OnnxModel model : models.values()) { + if ("".equals(model.getFileReference())) + log.warning("Illegal file reference " + model); // Let tests pass ... we should find a better way + else { + OnnxModelsConfig.Model.Builder modelBuilder = new OnnxModelsConfig.Model.Builder(); + modelBuilder.dry_run_on_setup(true); + modelBuilder.name(model.getName()); + modelBuilder.fileref(model.getFileReference()); + model.getInputMap().forEach((name, source) -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source))); + model.getOutputMap().forEach((name, as) -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(name).as(as))); + if (model.getStatelessExecutionMode().isPresent()) + modelBuilder.stateless_execution_mode(model.getStatelessExecutionMode().get()); + if (model.getStatelessInterOpThreads().isPresent()) + modelBuilder.stateless_interop_threads(model.getStatelessInterOpThreads().get()); + if (model.getStatelessIntraOpThreads().isPresent()) + modelBuilder.stateless_intraop_threads(model.getStatelessIntraOpThreads().get()); + + builder.model(modelBuilder); + } + } + } + +} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java index 65c0963856a..7384f98b121 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java @@ -5,9 +5,8 @@ import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels; import com.yahoo.config.model.api.ModelContext; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.search.query.profile.QueryProfileRegistry; -import com.yahoo.searchdefinition.OnnxModel; -import com.yahoo.searchdefinition.OnnxModels; import com.yahoo.searchdefinition.LargeRankExpressions; +import com.yahoo.searchdefinition.OnnxModel; import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.vespa.config.search.RankProfilesConfig; @@ -23,12 +22,10 @@ import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; -import java.util.logging.Logger; /** * The derived rank profiles of a schema @@ -37,19 +34,17 @@ import java.util.logging.Logger; */ public class RankProfileList extends Derived implements RankProfilesConfig.Producer { - private static final Logger log = Logger.getLogger(RankProfileList.class.getName()); - private final Map<String, RawRankProfile> rankProfiles; private final FileDistributedConstants constants; private final LargeRankExpressions largeRankExpressions; - private final OnnxModels onnxModels; + private final FileDistributedOnnxModels onnxModels; public static final RankProfileList empty = new RankProfileList(); private RankProfileList() { constants = new FileDistributedConstants(null, List.of()); largeRankExpressions = new LargeRankExpressions(null); - onnxModels = new OnnxModels(null, Optional.empty()); + onnxModels = new FileDistributedOnnxModels(null, List.of()); rankProfiles = Map.of(); } @@ -60,46 +55,16 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ * @param attributeFields the attribute fields to create a ranking for */ public RankProfileList(Schema schema, - Map<Reference, RankProfile.Constant> constantsFromSchema, LargeRankExpressions largeRankExpressions, - OnnxModels onnxModels, AttributeFields attributeFields, DeployState deployState) { setName(schema == null ? "default" : schema.getName()); this.largeRankExpressions = largeRankExpressions; - this.onnxModels = onnxModels; // as ONNX models come from parsing rank expressions this.rankProfiles = deriveRankProfiles(schema, attributeFields, deployState); - this.constants = deriveFileDistributedConstants(schema, constantsFromSchema, rankProfiles.values(), deployState); - } - - private static FileDistributedConstants deriveFileDistributedConstants(Schema schema, - Map<Reference, RankProfile.Constant> constantsFromSchema, - Collection<RawRankProfile> rankProfiles, - DeployState deployState) { - Map<Reference, RankProfile.Constant> allFileConstants = new HashMap<>(); - addFileConstants(constantsFromSchema.values(), allFileConstants, schema != null ? schema.toString() : "[global]"); - for (var profile : rankProfiles) - addFileConstants(profile.compiled().constants().values(), allFileConstants, profile.toString()); - return new FileDistributedConstants(deployState.getFileRegistry(), allFileConstants.values()); + this.constants = deriveFileDistributedConstants(schema, rankProfiles.values(), deployState); + this.onnxModels = deriveFileDistributedOnnxModels(schema, rankProfiles.values(), deployState); } - private static void addFileConstants(Collection<RankProfile.Constant> source, - Map<Reference, RankProfile.Constant> destination, - String sourceName) { - for (var constant : source) { - if (constant.valuePath().isEmpty()) continue; - var existing = destination.get(constant.name()); - if ( existing != null && ! constant.equals(existing)) { - throw new IllegalArgumentException("Duplicate constants: " + sourceName + " have " + constant + - ", but we already have " + existing + - ": Value reference constants must be unique across all rank profiles/models"); - } - destination.put(constant.name(), constant); - } - } - - public FileDistributedConstants constants() { return constants; } - private boolean areDependenciesReady(RankProfile rank, RankProfileRegistry registry, Set<String> processedProfiles) { return rank.inheritedNames().isEmpty() || processedProfiles.containsAll(rank.inheritedNames()) || @@ -132,7 +97,6 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ rawRankProfiles.putAll(processRankProfiles(ready, deployState.getQueryProfiles().getRegistry(), deployState.getImportedModels(), - schema, attributeFields, deployState.getProperties(), deployState.getExecutor())); @@ -141,20 +105,15 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ return rawRankProfiles; } - private Map<String, RawRankProfile> processRankProfiles(List<RankProfile> ready, + private Map<String, RawRankProfile> processRankProfiles(List<RankProfile> profiles, QueryProfileRegistry queryProfiles, ImportedMlModels importedModels, - Schema schema, AttributeFields attributeFields, ModelContext.Properties deployProperties, ExecutorService executor) { Map<String, Future<RawRankProfile>> futureRawRankProfiles = new LinkedHashMap<>(); - for (RankProfile rank : ready) { - if (schema == null) { - onnxModels.add(rank.onnxModels()); - } - - futureRawRankProfiles.put(rank.name(), executor.submit(() -> new RawRankProfile(rank, largeRankExpressions, queryProfiles, importedModels, + for (RankProfile profile : profiles) { + futureRawRankProfiles.put(profile.name(), executor.submit(() -> new RawRankProfile(profile, largeRankExpressions, queryProfiles, importedModels, attributeFields, deployProperties))); } try { @@ -169,19 +128,63 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ } } - public OnnxModels getOnnxModels() { - return onnxModels; + private static FileDistributedConstants deriveFileDistributedConstants(Schema schema, + Collection<RawRankProfile> rankProfiles, + DeployState deployState) { + Map<Reference, RankProfile.Constant> allFileConstants = new HashMap<>(); + addFileConstants(schema != null ? schema.constants().values() : List.of(), + allFileConstants, + schema != null ? schema.toString() : "[global]"); + for (var profile : rankProfiles) + addFileConstants(profile.compiled().constants().values(), allFileConstants, profile.toString()); + return new FileDistributedConstants(deployState.getFileRegistry(), allFileConstants.values()); } - public Map<String, RawRankProfile> getRankProfiles() { - return rankProfiles; + private static void addFileConstants(Collection<RankProfile.Constant> source, + Map<Reference, RankProfile.Constant> destination, + String sourceName) { + for (var constant : source) { + if (constant.valuePath().isEmpty()) continue; + var existing = destination.get(constant.name()); + if ( existing != null && ! constant.equals(existing)) { + throw new IllegalArgumentException("Duplicate constants: " + sourceName + " have " + constant + + ", but we already have " + existing + + ": Value reference constants must be unique across all rank profiles/models"); + } + destination.put(constant.name(), constant); + } } - /** Returns the raw rank profile with the given name, or null if it is not present */ - public RawRankProfile getRankProfile(String name) { - return rankProfiles.get(name); + private static FileDistributedOnnxModels deriveFileDistributedOnnxModels(Schema schema, + Collection<RawRankProfile> rankProfiles, + DeployState deployState) { + Map<String, OnnxModel> allModels = new HashMap<>(); + addOnnxModels(schema != null ? schema.onnxModels().values() : List.of(), + allModels, + schema != null ? schema.toString() : "[global]"); + for (var profile : rankProfiles) + addOnnxModels(profile.compiled().onnxModels().values(), allModels, profile.toString()); + return new FileDistributedOnnxModels(deployState.getFileRegistry(), allModels.values()); } + private static void addOnnxModels(Collection<OnnxModel> source, + Map<String, OnnxModel> destination, + String sourceName) { + for (var model : source) { + var existing = destination.get(model.getName()); + if ( existing != null && ! model.equals(existing)) { + throw new IllegalArgumentException("Duplicate onnx model: " + sourceName + " have " + model + + ", but we already have " + existing + + ": Onnx models must be unique across all rank profiles/models"); + } + destination.put(model.getName(), model); + } + } + + public Map<String, RawRankProfile> getRankProfiles() { return rankProfiles; } + public FileDistributedConstants constants() { return constants; } + public FileDistributedOnnxModels getOnnxModels() { return onnxModels; } + @Override public String getDerivedName() { return "rank-profiles"; } @@ -201,25 +204,7 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ } public void getConfig(OnnxModelsConfig.Builder builder) { - for (OnnxModel model : onnxModels.asMap().values()) { - if ("".equals(model.getFileReference())) - log.warning("Illegal file reference " + model); // Let tests pass ... we should find a better way - else { - OnnxModelsConfig.Model.Builder modelBuilder = new OnnxModelsConfig.Model.Builder(); - modelBuilder.dry_run_on_setup(true); - modelBuilder.name(model.getName()); - modelBuilder.fileref(model.getFileReference()); - model.getInputMap().forEach((name, source) -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source))); - model.getOutputMap().forEach((name, as) -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(name).as(as))); - if (model.getStatelessExecutionMode().isPresent()) - modelBuilder.stateless_execution_mode(model.getStatelessExecutionMode().get()); - if (model.getStatelessInterOpThreads().isPresent()) - modelBuilder.stateless_interop_threads(model.getStatelessInterOpThreads().get()); - if (model.getStatelessIntraOpThreads().isPresent()) - modelBuilder.stateless_intraop_threads(model.getStatelessIntraOpThreads().get()); - - builder.model(modelBuilder); - } - } + onnxModels.getConfig(builder); } + } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java index 2189d9d97db..dcd5019cf58 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java @@ -483,7 +483,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { private void deriveOnnxModelFunctionsAndFeatures(RankProfile rankProfile) { if (rankProfile.schema() == null) return; - if (rankProfile.schema().onnxModels().asMap().isEmpty()) return; + if (rankProfile.onnxModels().isEmpty()) return; replaceOnnxFunctionInputs(rankProfile); replaceImplicitOnnxConfigFeatures(summaryFeatures, rankProfile); replaceImplicitOnnxConfigFeatures(matchFeatures, rankProfile); @@ -492,7 +492,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { private void replaceOnnxFunctionInputs(RankProfile rankProfile) { Set<String> functionNames = rankProfile.getFunctions().keySet(); if (functionNames.isEmpty()) return; - for (OnnxModel onnxModel: rankProfile.schema().onnxModels().asMap().values()) { + for (OnnxModel onnxModel: rankProfile.onnxModels().values()) { for (Map.Entry<String, String> mapping : onnxModel.getInputMap().entrySet()) { String source = mapping.getValue(); if (functionNames.contains(source)) { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedSchemas.java b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedSchemas.java index e56245d1332..4c32f11c20d 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedSchemas.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedSchemas.java @@ -207,12 +207,10 @@ public class ConvertParsedSchemas { if (documentsOnly) { return; // skip ranking-only content, not used for document type generation } - for (var constant : parsed.getConstants()) { + for (var constant : parsed.getConstants()) schema.add(constant); - } - for (var onnxModel : parsed.getOnnxModels()) { - schema.onnxModels().add(onnxModel); - } + for (var onnxModel : parsed.getOnnxModels()) + schema.add(onnxModel); rankProfileRegistry.add(new DefaultRankProfile(schema, rankProfileRegistry)); rankProfileRegistry.add(new UnrankedRankProfile(schema, rankProfileRegistry)); var rankConverter = new ConvertParsedRanking(rankProfileRegistry); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java index fdbde08d926..19fbc116558 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java @@ -86,10 +86,8 @@ public class OnnxModelConfigGenerator extends Processor { } OnnxModel onnxModel = schema.onnxModels().get(modelConfigName); - if (onnxModel == null) { - onnxModel = new OnnxModel(modelConfigName, path); - schema.onnxModels().add(onnxModel); - } + if (onnxModel == null) + schema.add(new OnnxModel(modelConfigName, path)); } } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java index 4153cca4b5b..83e0c367292 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java @@ -4,6 +4,7 @@ package com.yahoo.searchdefinition.processing; import com.yahoo.config.application.api.DeployLogger; import com.yahoo.searchdefinition.OnnxModel; +import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.Schema; import com.yahoo.vespa.model.container.search.QueryProfiles; @@ -28,9 +29,11 @@ public class OnnxModelTypeResolver extends Processor { @Override public void process(boolean validate, boolean documentsOnly) { if (documentsOnly) return; - for (OnnxModel onnxModel : schema.onnxModels().asMap().values()) { - OnnxModelInfo onnxModelInfo = OnnxModelInfo.load(onnxModel.getFileName(), schema.applicationPackage()); - onnxModel.setModelInfo(onnxModelInfo); + for (OnnxModel onnxModel : schema.declaredOnnxModels().values()) + onnxModel.setModelInfo(OnnxModelInfo.load(onnxModel.getFileName(), schema.applicationPackage())); + for (RankProfile profile : rankProfileRegistry.rankProfilesOf(schema)) { + for (OnnxModel onnxModel : profile.declaredOnnxModels().values()) + onnxModel.setModelInfo(OnnxModelInfo.load(onnxModel.getFileName(), schema.applicationPackage())); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java index 03fc7592b40..149326130be 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java @@ -32,7 +32,7 @@ import com.yahoo.container.QrConfig; import com.yahoo.path.Path; import com.yahoo.searchdefinition.LargeRankExpressions; import com.yahoo.searchdefinition.OnnxModel; -import com.yahoo.searchdefinition.OnnxModels; +import com.yahoo.searchdefinition.derived.FileDistributedOnnxModels; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.derived.AttributeFields; @@ -175,11 +175,9 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri VespaModelBuilder builder = new VespaDomBuilder(); root = builder.getRoot(VespaModel.ROOT_CONFIGID, deployState, this); - createGlobalRankProfiles(deployState, deployState.getFileRegistry()); + createGlobalRankProfiles(deployState); rankProfileList = new RankProfileList(null, // null search -> global - Map.of(), new LargeRankExpressions(deployState.getFileRegistry()), - new OnnxModels(deployState.getFileRegistry(), Optional.empty()), AttributeFields.empty, deployState); @@ -274,7 +272,7 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri * Creates a rank profile not attached to any search definition, for each imported model in the application package, * and adds it to the given rank profile registry. */ - private void createGlobalRankProfiles(DeployState deployState, FileRegistry fileRegistry) { + private void createGlobalRankProfiles(DeployState deployState) { var importedModels = deployState.getImportedModels().all(); DeployLogger deployLogger = deployState.getDeployLogger(); RankProfileRegistry rankProfileRegistry = deployState.rankProfileRegistry(); @@ -283,8 +281,8 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri if ( ! importedModels.isEmpty()) { // models/ directory is available for (ImportedMlModel model : importedModels) { // Due to automatic naming not guaranteeing unique names, there must be a 1-1 between OnnxModels and global RankProfiles. - OnnxModels onnxModels = onnxModelInfoFromSource(model, fileRegistry); - RankProfile profile = new RankProfile(model.name(), applicationPackage, deployLogger, rankProfileRegistry, onnxModels); + RankProfile profile = new RankProfile(model.name(), applicationPackage, deployLogger, rankProfileRegistry); + addOnnxModelInfoFromSource(model, profile); rankProfileRegistry.add(profile); futureModels.add(deployState.getExecutor().submit(() -> { ConvertedModel convertedModel = ConvertedModel.fromSource(applicationPackage, new ModelName(model.name()), @@ -300,8 +298,8 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri String modelName = generatedModelDir.getPath().last(); if (modelName.contains(".")) continue; // Name space: Not a global profile // Due to automatic naming not guaranteeing unique names, there must be a 1-1 between OnnxModels and global RankProfiles. - OnnxModels onnxModels = onnxModelInfoFromStore(modelName, fileRegistry); - RankProfile profile = new RankProfile(modelName, applicationPackage, deployLogger, rankProfileRegistry, onnxModels); + RankProfile profile = new RankProfile(modelName, applicationPackage, deployLogger, rankProfileRegistry); + addOnnxModelInfoFromStore(modelName, profile); rankProfileRegistry.add(profile); futureModels.add(deployState.getExecutor().submit(() -> { ConvertedModel convertedModel = ConvertedModel.fromStore(applicationPackage, new ModelName(modelName), modelName, profile); @@ -320,27 +318,23 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri new Processing().processRankProfiles(deployLogger, rankProfileRegistry, queryProfiles, true, false); } - private OnnxModels onnxModelInfoFromSource(ImportedMlModel model, FileRegistry fileRegistry) { - OnnxModels onnxModels = new OnnxModels(fileRegistry, Optional.empty()); + private void addOnnxModelInfoFromSource(ImportedMlModel model, RankProfile profile) { if (model.modelType() == ImportedMlModel.ModelType.ONNX) { String path = model.source(); String applicationPath = this.applicationPackage.getFileReference(Path.fromString("")).toString(); if (path.startsWith(applicationPath)) { path = path.substring(applicationPath.length() + 1); } - loadOnnxModelInfo(onnxModels, model.name(), path); + addOnnxModelInfo(model.name(), path, profile); } - return onnxModels; } - private OnnxModels onnxModelInfoFromStore(String modelName, FileRegistry fileRegistry) { + private void addOnnxModelInfoFromStore(String modelName, RankProfile profile) { String path = ApplicationPackage.MODELS_DIR.append(modelName + ".onnx").toString(); - OnnxModels onnxModels = new OnnxModels(fileRegistry, Optional.empty()); - loadOnnxModelInfo(onnxModels, modelName, path); - return onnxModels; + addOnnxModelInfo(modelName, path, profile); } - private void loadOnnxModelInfo(OnnxModels onnxModels, String name, String path) { + private void addOnnxModelInfo(String name, String path, RankProfile profile) { boolean modelExists = OnnxModelInfo.modelExists(path, this.applicationPackage); if ( ! modelExists) { path = ApplicationPackage.MODELS_DIR.append(path).toString(); @@ -351,7 +345,7 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri if (onnxModelInfo.getModelPath() != null) { OnnxModel onnxModel = new OnnxModel(name, onnxModelInfo.getModelPath()); onnxModel.setModelInfo(onnxModelInfo); - onnxModels.add(onnxModel); + profile.add(onnxModel); } } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantValidator.java index b033d7a1e3b..fb70f2b769c 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantValidator.java @@ -25,7 +25,7 @@ public class ConstantValidator extends Validator { public void validate(VespaModel model, DeployState deployState) { var exceptionMessageCollector = new ExceptionMessageCollector("Invalid constant tensor file(s):"); for (Schema schema : deployState.getSchemas()) { - for (var constant : schema.constants().values()) + for (var constant : schema.declaredConstants().values()) validate(constant, deployState.getApplicationPackage(), exceptionMessageCollector); for (var profile : deployState.rankProfileRegistry().rankProfilesOf(schema)) { for (var constant : profile.declaredConstants().values()) diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java index b9b7f122d63..3492ccf0b21 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java @@ -152,7 +152,7 @@ public class RankSetupValidator extends Validator { List<String> config = new ArrayList<>(); // Assist verify-ranksetup in finding the actual ONNX model files - writeExtraVerifyRankSetupConfig(config, db.getDerivedConfiguration().getSchema().onnxModels().asMap().values()); + writeExtraVerifyRankSetupConfig(config, db.getDerivedConfiguration().getRankProfileList().getOnnxModels().asMap().values()); writeExtraVerifyRankSetupConfig(config, db.getDerivedConfiguration().getSchema().rankExpressionFiles().expressions()); config.sort(String::compareTo); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java index bb84f809fc4..17640417b3f 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java @@ -598,7 +598,7 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { Element onnxElement = XML.getChild(modelEvaluationElement, "onnx"); Element modelsElement = XML.getChild(onnxElement, "models"); for (Element modelElement : XML.getChildren(modelsElement, "model") ) { - OnnxModel onnxModel = profiles.getOnnxModels().get(modelElement.getAttribute("name")); + OnnxModel onnxModel = profiles.getOnnxModels().asMap().get(modelElement.getAttribute("name")); if (onnxModel == null) continue; // Skip if model is not found onnxModel.setStatelessExecutionMode(getStringValue(modelElement, "execution-mode", null)); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/SchemaTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/SchemaTestCase.java index 65f4dab3650..cb57746d82f 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/SchemaTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/SchemaTestCase.java @@ -196,8 +196,8 @@ public class SchemaTestCase { assertTrue(child1profile.constants().containsKey(FeatureNames.asConstantFeature("parent_constant"))); assertNotNull(child1.onnxModels().get("parent_model")); assertNotNull(child1.onnxModels().get("child1_model")); - assertTrue(child1.onnxModels().asMap().containsKey("parent_model")); - assertTrue(child1.onnxModels().asMap().containsKey("child1_model")); + assertTrue(child1.onnxModels().containsKey("parent_model")); + assertTrue(child1.onnxModels().containsKey("child1_model")); assertNotNull(child1.getSummary("parent_summary")); assertNotNull(child1.getSummary("child1_summary")); assertEquals("parent_summary", child1.getSummary("child1_summary").inherited().get().getName()); @@ -231,8 +231,8 @@ public class SchemaTestCase { assertTrue(child2.constants().containsKey(FeatureNames.asConstantFeature("child2_constant"))); assertNotNull(child2.onnxModels().get("parent_model")); assertNotNull(child2.onnxModels().get("child2_model")); - assertTrue(child2.onnxModels().asMap().containsKey("parent_model")); - assertTrue(child2.onnxModels().asMap().containsKey("child2_model")); + assertTrue(child2.onnxModels().containsKey("parent_model")); + assertTrue(child2.onnxModels().containsKey("child2_model")); assertNotNull(child2.getSummary("parent_summary")); assertNotNull(child2.getSummary("child2_summary")); assertEquals("parent_summary", child2.getSummary("child2_summary").inherited().get().getName()); @@ -430,7 +430,7 @@ public class SchemaTestCase { assertNotNull(schema.constants().get(FeatureNames.asConstantFeature("parent_constant"))); assertTrue(schema.constants().containsKey(FeatureNames.asConstantFeature("parent_constant"))); assertNotNull(schema.onnxModels().get("parent_model")); - assertTrue(schema.onnxModels().asMap().containsKey("parent_model")); + assertTrue(schema.onnxModels().containsKey("parent_model")); assertNotNull(schema.getSummary("parent_summary")); assertTrue(schema.getSummaries().containsKey("parent_summary")); assertNotNull(schema.getSummaryField("pf1")); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/GeminiTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/GeminiTestCase.java index 07e6fbf7b1b..207792ffe06 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/GeminiTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/GeminiTestCase.java @@ -19,7 +19,7 @@ public class GeminiTestCase extends AbstractExportingTestCase { @Test public void testRanking2() throws IOException, ParseException { DerivedConfiguration c = assertCorrectDeriving("gemini2"); - RawRankProfile p = c.getRankProfileList().getRankProfile("test"); + RawRankProfile p = c.getRankProfileList().getRankProfiles().get("test"); Map<String, String> ranking = removePartKeySuffixes(asMap(p.configProperties())); assertEquals("attribute(right)", resolve(lookup("toplevel", ranking), ranking)); } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/VespaMlModelTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/VespaMlModelTestCase.java index e0427d93ee4..4446f01aa95 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/VespaMlModelTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/VespaMlModelTestCase.java @@ -68,7 +68,7 @@ public class VespaMlModelTestCase { private String rankConfigOf(String rankProfileName, VespaModel model) { StringBuilder b = new StringBuilder(); - RawRankProfile profile = model.rankProfileList().getRankProfile(rankProfileName); + RawRankProfile profile = model.rankProfileList().getRankProfiles().get(rankProfileName); for (var property : profile.configProperties()) b.append(property.getFirst()).append(" : ").append(property.getSecond()).append("\n"); return b.toString(); |