diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition')
11 files changed, 197 insertions, 178 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/ImmutableSchema.java b/config-model/src/main/java/com/yahoo/searchdefinition/ImmutableSchema.java index 1c643292a05..3719313179f 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/ImmutableSchema.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/ImmutableSchema.java @@ -4,6 +4,7 @@ package com.yahoo.searchdefinition; import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.config.application.api.DeployLogger; import com.yahoo.config.model.api.ModelContext; +import com.yahoo.searchdefinition.derived.FileDistributedOnnxModels; import com.yahoo.searchdefinition.document.ImmutableSDField; import com.yahoo.searchdefinition.document.SDDocumentType; import com.yahoo.searchdefinition.document.SDField; @@ -37,7 +38,7 @@ public interface ImmutableSchema { ModelContext.Properties getDeployProperties(); Map<Reference, RankProfile.Constant> constants(); LargeRankExpressions rankExpressionFiles(); - OnnxModels onnxModels(); + Map<String, OnnxModel> onnxModels(); Stream<ImmutableSDField> allImportedFields(); SDDocumentType getDocument(); ImmutableSDField getField(String name); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java deleted file mode 100644 index c9c12100552..00000000000 --- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchdefinition; - -import com.yahoo.config.application.api.FileRegistry; - -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; -import java.util.Optional; - -/** - * ONNX models tied to a search definition or global. - * - * @author lesters - */ -public class OnnxModels { - - private final FileRegistry fileRegistry; - - /** The schema this belongs to, or empty if it is global */ - private final Optional<Schema> owner; - - private final Map<String, OnnxModel> models = new HashMap<>(); - - public OnnxModels(FileRegistry fileRegistry, Optional<Schema> owner) { - this.fileRegistry = fileRegistry; - this.owner = owner; - } - - public void add(OnnxModel model) { - model.validate(); - model.register(fileRegistry); - String name = model.getName(); - models.put(name, model); - } - - public void add(Map<String, OnnxModel> models) { - models.values().forEach(this::add); - } - - public OnnxModel get(String name) { - var model = models.get(name); - if (model != null) return model; - if (owner.isPresent() && owner.get().inherited().isPresent()) - return owner.get().inherited().get().onnxModels().get(name); - return null; - } - - public boolean has(String name) { - boolean has = models.containsKey(name); - if (has) return true; - if (owner.isPresent() && owner.get().inherited().isPresent()) - return owner.get().inherited().get().onnxModels().has(name); - return false; - } - - public Map<String, OnnxModel> asMap() { - // Shortcuts - if (owner.isEmpty() || owner.get().inherited().isEmpty()) return Collections.unmodifiableMap(models); - if (models.isEmpty()) return owner.get().inherited().get().onnxModels().asMap(); - - var allModels = new HashMap<>(owner.get().inherited().get().onnxModels().asMap()); - allModels.putAll(models); - return Collections.unmodifiableMap(allModels); - } - -} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index fd7e7e57f00..07f3048af04 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -118,6 +118,8 @@ public class RankProfile implements Cloneable { private Map<Reference, Constant> constants = new HashMap<>(); + private Map<String, OnnxModel> onnxModels = new HashMap<>(); + private Set<String> filterFields = new HashSet<>(); private final RankProfileRegistry rankProfileRegistry; @@ -128,8 +130,6 @@ public class RankProfile implements Cloneable { private Boolean strict; - /** Global onnx models not tied to a schema */ - private final OnnxModels onnxModels; private final ApplicationPackage applicationPackage; private final DeployLogger deployLogger; @@ -144,7 +144,6 @@ public class RankProfile implements Cloneable { public RankProfile(String name, Schema schema, RankProfileRegistry rankProfileRegistry) { this.name = Objects.requireNonNull(name, "name cannot be null"); this.schema = Objects.requireNonNull(schema, "schema cannot be null"); - this.onnxModels = null; this.rankProfileRegistry = rankProfileRegistry; this.applicationPackage = schema.applicationPackage(); this.deployLogger = schema.getDeployLogger(); @@ -156,11 +155,10 @@ public class RankProfile implements Cloneable { * @param name the name of the new profile */ public RankProfile(String name, ApplicationPackage applicationPackage, DeployLogger deployLogger, - RankProfileRegistry rankProfileRegistry, OnnxModels onnxModels) { + RankProfileRegistry rankProfileRegistry) { this.name = Objects.requireNonNull(name, "name cannot be null"); this.schema = null; this.rankProfileRegistry = rankProfileRegistry; - this.onnxModels = onnxModels; this.applicationPackage = applicationPackage; this.deployLogger = deployLogger; } @@ -175,10 +173,6 @@ public class RankProfile implements Cloneable { return applicationPackage; } - public Map<String, OnnxModel> onnxModels() { - return schema != null ? schema.onnxModels().asMap() : onnxModels.asMap(); - } - private Stream<ImmutableSDField> allFields() { if (schema == null) return Stream.empty(); if (allFieldsList == null) { @@ -411,6 +405,9 @@ public class RankProfile implements Cloneable { constants.put(constant.name(), constant); } + /** Returns an unmodifiable view of the constants declared in this */ + public Map<Reference, Constant> declaredConstants() { return Collections.unmodifiableMap(constants); } + /** Returns an unmodifiable view of the constants available in this */ public Map<Reference, Constant> constants() { Map<Reference, Constant> allConstants = new HashMap<>(); @@ -430,8 +427,31 @@ public class RankProfile implements Cloneable { return allConstants; } - /** Returns an unmodifiable view of the constants declared in this */ - public Map<Reference, Constant> declaredConstants() { return Collections.unmodifiableMap(constants); } + public void add(OnnxModel model) { + onnxModels.put(model.getName(), model); + } + + /** Returns an unmodifiable map of the onnx models declared in this. */ + public Map<String, OnnxModel> declaredOnnxModels() { return onnxModels; } + + /** Returns an unmodifiable map of the onnx models available in this. */ + public Map<String, OnnxModel> onnxModels() { + Map<String, OnnxModel> allModels = new HashMap<>(); + for (var inheritedProfile : inherited()) { + for (var model : inheritedProfile.onnxModels().values()) { + if (allModels.containsKey(model.getName())) + throw new IllegalArgumentException(model + "' is present in " + + inheritedProfile + " inherited by " + + this + ", but is also present in another profile inherited by it"); + allModels.put(model.getName(), model); + } + } + + if (schema != null) + allModels.putAll(schema.onnxModels()); + allModels.putAll(onnxModels); + return allModels; + } public void addAttributeType(String attributeName, String attributeType) { attributeTypes.addType(attributeName, attributeType); @@ -1064,10 +1084,8 @@ public class RankProfile implements Cloneable { } // Add output types for ONNX models - for (Map.Entry<String, OnnxModel> entry : onnxModels().entrySet()) { - String modelName = entry.getKey(); - OnnxModel model = entry.getValue(); - Arguments args = new Arguments(new ReferenceNode(modelName)); + for (var model : onnxModels().values()) { + Arguments args = new Arguments(new ReferenceNode(model.getName())); Map<String, TensorType> inputTypes = resolveOnnxInputTypes(model, context); TensorType defaultOutputType = model.getTensorType(model.getDefaultOutput(), inputTypes); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/Schema.java b/config-model/src/main/java/com/yahoo/searchdefinition/Schema.java index 147fee05820..8139d46cc0a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/Schema.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/Schema.java @@ -9,6 +9,7 @@ import com.yahoo.config.model.application.provider.BaseDeployLogger; import com.yahoo.config.model.deploy.TestProperties; import com.yahoo.document.DataTypeName; import com.yahoo.document.Field; +import com.yahoo.searchdefinition.derived.FileDistributedOnnxModels; import com.yahoo.searchdefinition.derived.SummaryClass; import com.yahoo.searchdefinition.document.Attribute; import com.yahoo.searchdefinition.document.ImmutableSDField; @@ -27,6 +28,7 @@ import java.io.Reader; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.LinkedHashMap; @@ -90,7 +92,8 @@ public class Schema implements ImmutableSchema { // TODO: Remove on Vespa 9: Should always be in a rank profile private final Map<Reference, RankProfile.Constant> constants = new LinkedHashMap<>(); - private final OnnxModels onnxModels; + // TODO: Remove on Vespa 9: Should always be in a rank profile + private final Map<String, OnnxModel> onnxModels = new LinkedHashMap<>(); /** All imported fields of this (and parent schemas) */ // TODO: Use empty, not optional @@ -150,7 +153,6 @@ public class Schema implements ImmutableSchema { this.properties = properties; this.documentsOnly = documentsOnly; largeRankExpressions = new LargeRankExpressions(fileRegistry); - onnxModels = new OnnxModels(fileRegistry, Optional.of(this)); } /** @@ -230,6 +232,10 @@ public class Schema implements ImmutableSchema { constants.put(constant.name(), constant); } + /** Returns an unmodifiable map of the constants declared in this. */ + public Map<Reference, RankProfile.Constant> declaredConstants() { return constants; } + + /** Returns an unmodifiable map of the constants available in this. */ @Override public Map<Reference, RankProfile.Constant> constants() { if (inherited().isEmpty()) return Collections.unmodifiableMap(constants); @@ -240,8 +246,23 @@ public class Schema implements ImmutableSchema { return allConstants; } + public void add(OnnxModel model) { + onnxModels.put(model.getName(), model); + } + + /** Returns an unmodifiable map of the onnx models declared in this. */ + public Map<String, OnnxModel> declaredOnnxModels() { return onnxModels; } + + /** Returns an unmodifiable map of the onnx models available in this. */ @Override - public OnnxModels onnxModels() { return onnxModels; } + public Map<String, OnnxModel> onnxModels() { + if (inherited().isEmpty()) return Collections.unmodifiableMap(onnxModels); + if (onnxModels.isEmpty()) return inherited().get().onnxModels(); + + Map<String, OnnxModel> allModels = new LinkedHashMap<>(inherited().get().onnxModels()); + allModels.putAll(onnxModels); + return allModels; + } public Optional<TemporaryImportedFields> temporaryImportedFields() { return temporaryImportedFields; diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java index 50029613b2e..c7b9b94a4b2 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java @@ -82,8 +82,7 @@ public class DerivedConfiguration implements AttributesConfig.Producer { summaries = new Summaries(schema, deployState.getDeployLogger(), deployState.getProperties().featureFlags()); summaryMap = new SummaryMap(schema); juniperrc = new Juniperrc(schema); - rankProfileList = new RankProfileList(schema, schema.constants(), schema.rankExpressionFiles(), - schema.onnxModels(), attributeFields, deployState); + rankProfileList = new RankProfileList(schema, schema.rankExpressionFiles(), attributeFields, deployState); indexingScript = new IndexingScript(schema); indexInfo = new IndexInfo(schema); schemaInfo = new SchemaInfo(schema, deployState.rankProfileRegistry(), summaries, summaryMap); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/FileDistributedOnnxModels.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/FileDistributedOnnxModels.java new file mode 100644 index 00000000000..a1310d6c18e --- /dev/null +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/FileDistributedOnnxModels.java @@ -0,0 +1,63 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchdefinition.derived; + +import com.yahoo.config.application.api.FileRegistry; +import com.yahoo.searchdefinition.OnnxModel; +import com.yahoo.searchdefinition.Schema; +import com.yahoo.vespa.config.search.core.OnnxModelsConfig; + +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Optional; +import java.util.logging.Logger; + +/** + * ONNX models distributed as files. + * + * @author bratseth + */ +public class FileDistributedOnnxModels { + + private static final Logger log = Logger.getLogger(FileDistributedOnnxModels.class.getName()); + + private final Map<String, OnnxModel> models; + + public FileDistributedOnnxModels(FileRegistry fileRegistry, Collection<OnnxModel> models) { + Map<String, OnnxModel> distributableModels = new LinkedHashMap<>(); + for (var model : models) { + model.validate(); + model.register(fileRegistry); + distributableModels.put(model.getName(), model); + } + this.models = Collections.unmodifiableMap(distributableModels); + } + + public Map<String, OnnxModel> asMap() { return models; } + + public void getConfig(OnnxModelsConfig.Builder builder) { + for (OnnxModel model : models.values()) { + if ("".equals(model.getFileReference())) + log.warning("Illegal file reference " + model); // Let tests pass ... we should find a better way + else { + OnnxModelsConfig.Model.Builder modelBuilder = new OnnxModelsConfig.Model.Builder(); + modelBuilder.dry_run_on_setup(true); + modelBuilder.name(model.getName()); + modelBuilder.fileref(model.getFileReference()); + model.getInputMap().forEach((name, source) -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source))); + model.getOutputMap().forEach((name, as) -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(name).as(as))); + if (model.getStatelessExecutionMode().isPresent()) + modelBuilder.stateless_execution_mode(model.getStatelessExecutionMode().get()); + if (model.getStatelessInterOpThreads().isPresent()) + modelBuilder.stateless_interop_threads(model.getStatelessInterOpThreads().get()); + if (model.getStatelessIntraOpThreads().isPresent()) + modelBuilder.stateless_intraop_threads(model.getStatelessIntraOpThreads().get()); + + builder.model(modelBuilder); + } + } + } + +} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java index 65c0963856a..7384f98b121 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java @@ -5,9 +5,8 @@ import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels; import com.yahoo.config.model.api.ModelContext; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.search.query.profile.QueryProfileRegistry; -import com.yahoo.searchdefinition.OnnxModel; -import com.yahoo.searchdefinition.OnnxModels; import com.yahoo.searchdefinition.LargeRankExpressions; +import com.yahoo.searchdefinition.OnnxModel; import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.vespa.config.search.RankProfilesConfig; @@ -23,12 +22,10 @@ import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; -import java.util.logging.Logger; /** * The derived rank profiles of a schema @@ -37,19 +34,17 @@ import java.util.logging.Logger; */ public class RankProfileList extends Derived implements RankProfilesConfig.Producer { - private static final Logger log = Logger.getLogger(RankProfileList.class.getName()); - private final Map<String, RawRankProfile> rankProfiles; private final FileDistributedConstants constants; private final LargeRankExpressions largeRankExpressions; - private final OnnxModels onnxModels; + private final FileDistributedOnnxModels onnxModels; public static final RankProfileList empty = new RankProfileList(); private RankProfileList() { constants = new FileDistributedConstants(null, List.of()); largeRankExpressions = new LargeRankExpressions(null); - onnxModels = new OnnxModels(null, Optional.empty()); + onnxModels = new FileDistributedOnnxModels(null, List.of()); rankProfiles = Map.of(); } @@ -60,46 +55,16 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ * @param attributeFields the attribute fields to create a ranking for */ public RankProfileList(Schema schema, - Map<Reference, RankProfile.Constant> constantsFromSchema, LargeRankExpressions largeRankExpressions, - OnnxModels onnxModels, AttributeFields attributeFields, DeployState deployState) { setName(schema == null ? "default" : schema.getName()); this.largeRankExpressions = largeRankExpressions; - this.onnxModels = onnxModels; // as ONNX models come from parsing rank expressions this.rankProfiles = deriveRankProfiles(schema, attributeFields, deployState); - this.constants = deriveFileDistributedConstants(schema, constantsFromSchema, rankProfiles.values(), deployState); - } - - private static FileDistributedConstants deriveFileDistributedConstants(Schema schema, - Map<Reference, RankProfile.Constant> constantsFromSchema, - Collection<RawRankProfile> rankProfiles, - DeployState deployState) { - Map<Reference, RankProfile.Constant> allFileConstants = new HashMap<>(); - addFileConstants(constantsFromSchema.values(), allFileConstants, schema != null ? schema.toString() : "[global]"); - for (var profile : rankProfiles) - addFileConstants(profile.compiled().constants().values(), allFileConstants, profile.toString()); - return new FileDistributedConstants(deployState.getFileRegistry(), allFileConstants.values()); + this.constants = deriveFileDistributedConstants(schema, rankProfiles.values(), deployState); + this.onnxModels = deriveFileDistributedOnnxModels(schema, rankProfiles.values(), deployState); } - private static void addFileConstants(Collection<RankProfile.Constant> source, - Map<Reference, RankProfile.Constant> destination, - String sourceName) { - for (var constant : source) { - if (constant.valuePath().isEmpty()) continue; - var existing = destination.get(constant.name()); - if ( existing != null && ! constant.equals(existing)) { - throw new IllegalArgumentException("Duplicate constants: " + sourceName + " have " + constant + - ", but we already have " + existing + - ": Value reference constants must be unique across all rank profiles/models"); - } - destination.put(constant.name(), constant); - } - } - - public FileDistributedConstants constants() { return constants; } - private boolean areDependenciesReady(RankProfile rank, RankProfileRegistry registry, Set<String> processedProfiles) { return rank.inheritedNames().isEmpty() || processedProfiles.containsAll(rank.inheritedNames()) || @@ -132,7 +97,6 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ rawRankProfiles.putAll(processRankProfiles(ready, deployState.getQueryProfiles().getRegistry(), deployState.getImportedModels(), - schema, attributeFields, deployState.getProperties(), deployState.getExecutor())); @@ -141,20 +105,15 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ return rawRankProfiles; } - private Map<String, RawRankProfile> processRankProfiles(List<RankProfile> ready, + private Map<String, RawRankProfile> processRankProfiles(List<RankProfile> profiles, QueryProfileRegistry queryProfiles, ImportedMlModels importedModels, - Schema schema, AttributeFields attributeFields, ModelContext.Properties deployProperties, ExecutorService executor) { Map<String, Future<RawRankProfile>> futureRawRankProfiles = new LinkedHashMap<>(); - for (RankProfile rank : ready) { - if (schema == null) { - onnxModels.add(rank.onnxModels()); - } - - futureRawRankProfiles.put(rank.name(), executor.submit(() -> new RawRankProfile(rank, largeRankExpressions, queryProfiles, importedModels, + for (RankProfile profile : profiles) { + futureRawRankProfiles.put(profile.name(), executor.submit(() -> new RawRankProfile(profile, largeRankExpressions, queryProfiles, importedModels, attributeFields, deployProperties))); } try { @@ -169,19 +128,63 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ } } - public OnnxModels getOnnxModels() { - return onnxModels; + private static FileDistributedConstants deriveFileDistributedConstants(Schema schema, + Collection<RawRankProfile> rankProfiles, + DeployState deployState) { + Map<Reference, RankProfile.Constant> allFileConstants = new HashMap<>(); + addFileConstants(schema != null ? schema.constants().values() : List.of(), + allFileConstants, + schema != null ? schema.toString() : "[global]"); + for (var profile : rankProfiles) + addFileConstants(profile.compiled().constants().values(), allFileConstants, profile.toString()); + return new FileDistributedConstants(deployState.getFileRegistry(), allFileConstants.values()); } - public Map<String, RawRankProfile> getRankProfiles() { - return rankProfiles; + private static void addFileConstants(Collection<RankProfile.Constant> source, + Map<Reference, RankProfile.Constant> destination, + String sourceName) { + for (var constant : source) { + if (constant.valuePath().isEmpty()) continue; + var existing = destination.get(constant.name()); + if ( existing != null && ! constant.equals(existing)) { + throw new IllegalArgumentException("Duplicate constants: " + sourceName + " have " + constant + + ", but we already have " + existing + + ": Value reference constants must be unique across all rank profiles/models"); + } + destination.put(constant.name(), constant); + } } - /** Returns the raw rank profile with the given name, or null if it is not present */ - public RawRankProfile getRankProfile(String name) { - return rankProfiles.get(name); + private static FileDistributedOnnxModels deriveFileDistributedOnnxModels(Schema schema, + Collection<RawRankProfile> rankProfiles, + DeployState deployState) { + Map<String, OnnxModel> allModels = new HashMap<>(); + addOnnxModels(schema != null ? schema.onnxModels().values() : List.of(), + allModels, + schema != null ? schema.toString() : "[global]"); + for (var profile : rankProfiles) + addOnnxModels(profile.compiled().onnxModels().values(), allModels, profile.toString()); + return new FileDistributedOnnxModels(deployState.getFileRegistry(), allModels.values()); } + private static void addOnnxModels(Collection<OnnxModel> source, + Map<String, OnnxModel> destination, + String sourceName) { + for (var model : source) { + var existing = destination.get(model.getName()); + if ( existing != null && ! model.equals(existing)) { + throw new IllegalArgumentException("Duplicate onnx model: " + sourceName + " have " + model + + ", but we already have " + existing + + ": Onnx models must be unique across all rank profiles/models"); + } + destination.put(model.getName(), model); + } + } + + public Map<String, RawRankProfile> getRankProfiles() { return rankProfiles; } + public FileDistributedConstants constants() { return constants; } + public FileDistributedOnnxModels getOnnxModels() { return onnxModels; } + @Override public String getDerivedName() { return "rank-profiles"; } @@ -201,25 +204,7 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ } public void getConfig(OnnxModelsConfig.Builder builder) { - for (OnnxModel model : onnxModels.asMap().values()) { - if ("".equals(model.getFileReference())) - log.warning("Illegal file reference " + model); // Let tests pass ... we should find a better way - else { - OnnxModelsConfig.Model.Builder modelBuilder = new OnnxModelsConfig.Model.Builder(); - modelBuilder.dry_run_on_setup(true); - modelBuilder.name(model.getName()); - modelBuilder.fileref(model.getFileReference()); - model.getInputMap().forEach((name, source) -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source))); - model.getOutputMap().forEach((name, as) -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(name).as(as))); - if (model.getStatelessExecutionMode().isPresent()) - modelBuilder.stateless_execution_mode(model.getStatelessExecutionMode().get()); - if (model.getStatelessInterOpThreads().isPresent()) - modelBuilder.stateless_interop_threads(model.getStatelessInterOpThreads().get()); - if (model.getStatelessIntraOpThreads().isPresent()) - modelBuilder.stateless_intraop_threads(model.getStatelessIntraOpThreads().get()); - - builder.model(modelBuilder); - } - } + onnxModels.getConfig(builder); } + } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java index 2189d9d97db..dcd5019cf58 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java @@ -483,7 +483,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { private void deriveOnnxModelFunctionsAndFeatures(RankProfile rankProfile) { if (rankProfile.schema() == null) return; - if (rankProfile.schema().onnxModels().asMap().isEmpty()) return; + if (rankProfile.onnxModels().isEmpty()) return; replaceOnnxFunctionInputs(rankProfile); replaceImplicitOnnxConfigFeatures(summaryFeatures, rankProfile); replaceImplicitOnnxConfigFeatures(matchFeatures, rankProfile); @@ -492,7 +492,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { private void replaceOnnxFunctionInputs(RankProfile rankProfile) { Set<String> functionNames = rankProfile.getFunctions().keySet(); if (functionNames.isEmpty()) return; - for (OnnxModel onnxModel: rankProfile.schema().onnxModels().asMap().values()) { + for (OnnxModel onnxModel: rankProfile.onnxModels().values()) { for (Map.Entry<String, String> mapping : onnxModel.getInputMap().entrySet()) { String source = mapping.getValue(); if (functionNames.contains(source)) { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedSchemas.java b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedSchemas.java index e56245d1332..4c32f11c20d 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedSchemas.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedSchemas.java @@ -207,12 +207,10 @@ public class ConvertParsedSchemas { if (documentsOnly) { return; // skip ranking-only content, not used for document type generation } - for (var constant : parsed.getConstants()) { + for (var constant : parsed.getConstants()) schema.add(constant); - } - for (var onnxModel : parsed.getOnnxModels()) { - schema.onnxModels().add(onnxModel); - } + for (var onnxModel : parsed.getOnnxModels()) + schema.add(onnxModel); rankProfileRegistry.add(new DefaultRankProfile(schema, rankProfileRegistry)); rankProfileRegistry.add(new UnrankedRankProfile(schema, rankProfileRegistry)); var rankConverter = new ConvertParsedRanking(rankProfileRegistry); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java index fdbde08d926..19fbc116558 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java @@ -86,10 +86,8 @@ public class OnnxModelConfigGenerator extends Processor { } OnnxModel onnxModel = schema.onnxModels().get(modelConfigName); - if (onnxModel == null) { - onnxModel = new OnnxModel(modelConfigName, path); - schema.onnxModels().add(onnxModel); - } + if (onnxModel == null) + schema.add(new OnnxModel(modelConfigName, path)); } } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java index 4153cca4b5b..83e0c367292 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java @@ -4,6 +4,7 @@ package com.yahoo.searchdefinition.processing; import com.yahoo.config.application.api.DeployLogger; import com.yahoo.searchdefinition.OnnxModel; +import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.Schema; import com.yahoo.vespa.model.container.search.QueryProfiles; @@ -28,9 +29,11 @@ public class OnnxModelTypeResolver extends Processor { @Override public void process(boolean validate, boolean documentsOnly) { if (documentsOnly) return; - for (OnnxModel onnxModel : schema.onnxModels().asMap().values()) { - OnnxModelInfo onnxModelInfo = OnnxModelInfo.load(onnxModel.getFileName(), schema.applicationPackage()); - onnxModel.setModelInfo(onnxModelInfo); + for (OnnxModel onnxModel : schema.declaredOnnxModels().values()) + onnxModel.setModelInfo(OnnxModelInfo.load(onnxModel.getFileName(), schema.applicationPackage())); + for (RankProfile profile : rankProfileRegistry.rankProfilesOf(schema)) { + for (OnnxModel onnxModel : profile.declaredOnnxModels().values()) + onnxModel.setModelInfo(OnnxModelInfo.load(onnxModel.getFileName(), schema.applicationPackage())); } } |