summaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/searchdefinition
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/ImmutableSchema.java3
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java67
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java48
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/Schema.java27
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java3
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/FileDistributedOnnxModels.java63
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java137
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java4
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedSchemas.java8
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java6
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java9
11 files changed, 197 insertions, 178 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/ImmutableSchema.java b/config-model/src/main/java/com/yahoo/searchdefinition/ImmutableSchema.java
index 1c643292a05..3719313179f 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/ImmutableSchema.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/ImmutableSchema.java
@@ -4,6 +4,7 @@ package com.yahoo.searchdefinition;
import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.config.application.api.DeployLogger;
import com.yahoo.config.model.api.ModelContext;
+import com.yahoo.searchdefinition.derived.FileDistributedOnnxModels;
import com.yahoo.searchdefinition.document.ImmutableSDField;
import com.yahoo.searchdefinition.document.SDDocumentType;
import com.yahoo.searchdefinition.document.SDField;
@@ -37,7 +38,7 @@ public interface ImmutableSchema {
ModelContext.Properties getDeployProperties();
Map<Reference, RankProfile.Constant> constants();
LargeRankExpressions rankExpressionFiles();
- OnnxModels onnxModels();
+ Map<String, OnnxModel> onnxModels();
Stream<ImmutableSDField> allImportedFields();
SDDocumentType getDocument();
ImmutableSDField getField(String name);
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java
deleted file mode 100644
index c9c12100552..00000000000
--- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java
+++ /dev/null
@@ -1,67 +0,0 @@
-// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchdefinition;
-
-import com.yahoo.config.application.api.FileRegistry;
-
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Optional;
-
-/**
- * ONNX models tied to a search definition or global.
- *
- * @author lesters
- */
-public class OnnxModels {
-
- private final FileRegistry fileRegistry;
-
- /** The schema this belongs to, or empty if it is global */
- private final Optional<Schema> owner;
-
- private final Map<String, OnnxModel> models = new HashMap<>();
-
- public OnnxModels(FileRegistry fileRegistry, Optional<Schema> owner) {
- this.fileRegistry = fileRegistry;
- this.owner = owner;
- }
-
- public void add(OnnxModel model) {
- model.validate();
- model.register(fileRegistry);
- String name = model.getName();
- models.put(name, model);
- }
-
- public void add(Map<String, OnnxModel> models) {
- models.values().forEach(this::add);
- }
-
- public OnnxModel get(String name) {
- var model = models.get(name);
- if (model != null) return model;
- if (owner.isPresent() && owner.get().inherited().isPresent())
- return owner.get().inherited().get().onnxModels().get(name);
- return null;
- }
-
- public boolean has(String name) {
- boolean has = models.containsKey(name);
- if (has) return true;
- if (owner.isPresent() && owner.get().inherited().isPresent())
- return owner.get().inherited().get().onnxModels().has(name);
- return false;
- }
-
- public Map<String, OnnxModel> asMap() {
- // Shortcuts
- if (owner.isEmpty() || owner.get().inherited().isEmpty()) return Collections.unmodifiableMap(models);
- if (models.isEmpty()) return owner.get().inherited().get().onnxModels().asMap();
-
- var allModels = new HashMap<>(owner.get().inherited().get().onnxModels().asMap());
- allModels.putAll(models);
- return Collections.unmodifiableMap(allModels);
- }
-
-}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
index fd7e7e57f00..07f3048af04 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
@@ -118,6 +118,8 @@ public class RankProfile implements Cloneable {
private Map<Reference, Constant> constants = new HashMap<>();
+ private Map<String, OnnxModel> onnxModels = new HashMap<>();
+
private Set<String> filterFields = new HashSet<>();
private final RankProfileRegistry rankProfileRegistry;
@@ -128,8 +130,6 @@ public class RankProfile implements Cloneable {
private Boolean strict;
- /** Global onnx models not tied to a schema */
- private final OnnxModels onnxModels;
private final ApplicationPackage applicationPackage;
private final DeployLogger deployLogger;
@@ -144,7 +144,6 @@ public class RankProfile implements Cloneable {
public RankProfile(String name, Schema schema, RankProfileRegistry rankProfileRegistry) {
this.name = Objects.requireNonNull(name, "name cannot be null");
this.schema = Objects.requireNonNull(schema, "schema cannot be null");
- this.onnxModels = null;
this.rankProfileRegistry = rankProfileRegistry;
this.applicationPackage = schema.applicationPackage();
this.deployLogger = schema.getDeployLogger();
@@ -156,11 +155,10 @@ public class RankProfile implements Cloneable {
* @param name the name of the new profile
*/
public RankProfile(String name, ApplicationPackage applicationPackage, DeployLogger deployLogger,
- RankProfileRegistry rankProfileRegistry, OnnxModels onnxModels) {
+ RankProfileRegistry rankProfileRegistry) {
this.name = Objects.requireNonNull(name, "name cannot be null");
this.schema = null;
this.rankProfileRegistry = rankProfileRegistry;
- this.onnxModels = onnxModels;
this.applicationPackage = applicationPackage;
this.deployLogger = deployLogger;
}
@@ -175,10 +173,6 @@ public class RankProfile implements Cloneable {
return applicationPackage;
}
- public Map<String, OnnxModel> onnxModels() {
- return schema != null ? schema.onnxModels().asMap() : onnxModels.asMap();
- }
-
private Stream<ImmutableSDField> allFields() {
if (schema == null) return Stream.empty();
if (allFieldsList == null) {
@@ -411,6 +405,9 @@ public class RankProfile implements Cloneable {
constants.put(constant.name(), constant);
}
+ /** Returns an unmodifiable view of the constants declared in this */
+ public Map<Reference, Constant> declaredConstants() { return Collections.unmodifiableMap(constants); }
+
/** Returns an unmodifiable view of the constants available in this */
public Map<Reference, Constant> constants() {
Map<Reference, Constant> allConstants = new HashMap<>();
@@ -430,8 +427,31 @@ public class RankProfile implements Cloneable {
return allConstants;
}
- /** Returns an unmodifiable view of the constants declared in this */
- public Map<Reference, Constant> declaredConstants() { return Collections.unmodifiableMap(constants); }
+ public void add(OnnxModel model) {
+ onnxModels.put(model.getName(), model);
+ }
+
+ /** Returns an unmodifiable map of the onnx models declared in this. */
+ public Map<String, OnnxModel> declaredOnnxModels() { return onnxModels; }
+
+ /** Returns an unmodifiable map of the onnx models available in this. */
+ public Map<String, OnnxModel> onnxModels() {
+ Map<String, OnnxModel> allModels = new HashMap<>();
+ for (var inheritedProfile : inherited()) {
+ for (var model : inheritedProfile.onnxModels().values()) {
+ if (allModels.containsKey(model.getName()))
+ throw new IllegalArgumentException(model + "' is present in " +
+ inheritedProfile + " inherited by " +
+ this + ", but is also present in another profile inherited by it");
+ allModels.put(model.getName(), model);
+ }
+ }
+
+ if (schema != null)
+ allModels.putAll(schema.onnxModels());
+ allModels.putAll(onnxModels);
+ return allModels;
+ }
public void addAttributeType(String attributeName, String attributeType) {
attributeTypes.addType(attributeName, attributeType);
@@ -1064,10 +1084,8 @@ public class RankProfile implements Cloneable {
}
// Add output types for ONNX models
- for (Map.Entry<String, OnnxModel> entry : onnxModels().entrySet()) {
- String modelName = entry.getKey();
- OnnxModel model = entry.getValue();
- Arguments args = new Arguments(new ReferenceNode(modelName));
+ for (var model : onnxModels().values()) {
+ Arguments args = new Arguments(new ReferenceNode(model.getName()));
Map<String, TensorType> inputTypes = resolveOnnxInputTypes(model, context);
TensorType defaultOutputType = model.getTensorType(model.getDefaultOutput(), inputTypes);
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/Schema.java b/config-model/src/main/java/com/yahoo/searchdefinition/Schema.java
index 147fee05820..8139d46cc0a 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/Schema.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/Schema.java
@@ -9,6 +9,7 @@ import com.yahoo.config.model.application.provider.BaseDeployLogger;
import com.yahoo.config.model.deploy.TestProperties;
import com.yahoo.document.DataTypeName;
import com.yahoo.document.Field;
+import com.yahoo.searchdefinition.derived.FileDistributedOnnxModels;
import com.yahoo.searchdefinition.derived.SummaryClass;
import com.yahoo.searchdefinition.document.Attribute;
import com.yahoo.searchdefinition.document.ImmutableSDField;
@@ -27,6 +28,7 @@ import java.io.Reader;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
+import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
@@ -90,7 +92,8 @@ public class Schema implements ImmutableSchema {
// TODO: Remove on Vespa 9: Should always be in a rank profile
private final Map<Reference, RankProfile.Constant> constants = new LinkedHashMap<>();
- private final OnnxModels onnxModels;
+ // TODO: Remove on Vespa 9: Should always be in a rank profile
+ private final Map<String, OnnxModel> onnxModels = new LinkedHashMap<>();
/** All imported fields of this (and parent schemas) */
// TODO: Use empty, not optional
@@ -150,7 +153,6 @@ public class Schema implements ImmutableSchema {
this.properties = properties;
this.documentsOnly = documentsOnly;
largeRankExpressions = new LargeRankExpressions(fileRegistry);
- onnxModels = new OnnxModels(fileRegistry, Optional.of(this));
}
/**
@@ -230,6 +232,10 @@ public class Schema implements ImmutableSchema {
constants.put(constant.name(), constant);
}
+ /** Returns an unmodifiable map of the constants declared in this. */
+ public Map<Reference, RankProfile.Constant> declaredConstants() { return constants; }
+
+ /** Returns an unmodifiable map of the constants available in this. */
@Override
public Map<Reference, RankProfile.Constant> constants() {
if (inherited().isEmpty()) return Collections.unmodifiableMap(constants);
@@ -240,8 +246,23 @@ public class Schema implements ImmutableSchema {
return allConstants;
}
+ public void add(OnnxModel model) {
+ onnxModels.put(model.getName(), model);
+ }
+
+ /** Returns an unmodifiable map of the onnx models declared in this. */
+ public Map<String, OnnxModel> declaredOnnxModels() { return onnxModels; }
+
+ /** Returns an unmodifiable map of the onnx models available in this. */
@Override
- public OnnxModels onnxModels() { return onnxModels; }
+ public Map<String, OnnxModel> onnxModels() {
+ if (inherited().isEmpty()) return Collections.unmodifiableMap(onnxModels);
+ if (onnxModels.isEmpty()) return inherited().get().onnxModels();
+
+ Map<String, OnnxModel> allModels = new LinkedHashMap<>(inherited().get().onnxModels());
+ allModels.putAll(onnxModels);
+ return allModels;
+ }
public Optional<TemporaryImportedFields> temporaryImportedFields() {
return temporaryImportedFields;
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java
index 50029613b2e..c7b9b94a4b2 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java
@@ -82,8 +82,7 @@ public class DerivedConfiguration implements AttributesConfig.Producer {
summaries = new Summaries(schema, deployState.getDeployLogger(), deployState.getProperties().featureFlags());
summaryMap = new SummaryMap(schema);
juniperrc = new Juniperrc(schema);
- rankProfileList = new RankProfileList(schema, schema.constants(), schema.rankExpressionFiles(),
- schema.onnxModels(), attributeFields, deployState);
+ rankProfileList = new RankProfileList(schema, schema.rankExpressionFiles(), attributeFields, deployState);
indexingScript = new IndexingScript(schema);
indexInfo = new IndexInfo(schema);
schemaInfo = new SchemaInfo(schema, deployState.rankProfileRegistry(), summaries, summaryMap);
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/FileDistributedOnnxModels.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/FileDistributedOnnxModels.java
new file mode 100644
index 00000000000..a1310d6c18e
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/FileDistributedOnnxModels.java
@@ -0,0 +1,63 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchdefinition.derived;
+
+import com.yahoo.config.application.api.FileRegistry;
+import com.yahoo.searchdefinition.OnnxModel;
+import com.yahoo.searchdefinition.Schema;
+import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.Optional;
+import java.util.logging.Logger;
+
+/**
+ * ONNX models distributed as files.
+ *
+ * @author bratseth
+ */
+public class FileDistributedOnnxModels {
+
+ private static final Logger log = Logger.getLogger(FileDistributedOnnxModels.class.getName());
+
+ private final Map<String, OnnxModel> models;
+
+ public FileDistributedOnnxModels(FileRegistry fileRegistry, Collection<OnnxModel> models) {
+ Map<String, OnnxModel> distributableModels = new LinkedHashMap<>();
+ for (var model : models) {
+ model.validate();
+ model.register(fileRegistry);
+ distributableModels.put(model.getName(), model);
+ }
+ this.models = Collections.unmodifiableMap(distributableModels);
+ }
+
+ public Map<String, OnnxModel> asMap() { return models; }
+
+ public void getConfig(OnnxModelsConfig.Builder builder) {
+ for (OnnxModel model : models.values()) {
+ if ("".equals(model.getFileReference()))
+ log.warning("Illegal file reference " + model); // Let tests pass ... we should find a better way
+ else {
+ OnnxModelsConfig.Model.Builder modelBuilder = new OnnxModelsConfig.Model.Builder();
+ modelBuilder.dry_run_on_setup(true);
+ modelBuilder.name(model.getName());
+ modelBuilder.fileref(model.getFileReference());
+ model.getInputMap().forEach((name, source) -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source)));
+ model.getOutputMap().forEach((name, as) -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(name).as(as)));
+ if (model.getStatelessExecutionMode().isPresent())
+ modelBuilder.stateless_execution_mode(model.getStatelessExecutionMode().get());
+ if (model.getStatelessInterOpThreads().isPresent())
+ modelBuilder.stateless_interop_threads(model.getStatelessInterOpThreads().get());
+ if (model.getStatelessIntraOpThreads().isPresent())
+ modelBuilder.stateless_intraop_threads(model.getStatelessIntraOpThreads().get());
+
+ builder.model(modelBuilder);
+ }
+ }
+ }
+
+}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
index 65c0963856a..7384f98b121 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
@@ -5,9 +5,8 @@ import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels;
import com.yahoo.config.model.api.ModelContext;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.search.query.profile.QueryProfileRegistry;
-import com.yahoo.searchdefinition.OnnxModel;
-import com.yahoo.searchdefinition.OnnxModels;
import com.yahoo.searchdefinition.LargeRankExpressions;
+import com.yahoo.searchdefinition.OnnxModel;
import com.yahoo.searchdefinition.RankProfileRegistry;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.vespa.config.search.RankProfilesConfig;
@@ -23,12 +22,10 @@ import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
-import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
-import java.util.logging.Logger;
/**
* The derived rank profiles of a schema
@@ -37,19 +34,17 @@ import java.util.logging.Logger;
*/
public class RankProfileList extends Derived implements RankProfilesConfig.Producer {
- private static final Logger log = Logger.getLogger(RankProfileList.class.getName());
-
private final Map<String, RawRankProfile> rankProfiles;
private final FileDistributedConstants constants;
private final LargeRankExpressions largeRankExpressions;
- private final OnnxModels onnxModels;
+ private final FileDistributedOnnxModels onnxModels;
public static final RankProfileList empty = new RankProfileList();
private RankProfileList() {
constants = new FileDistributedConstants(null, List.of());
largeRankExpressions = new LargeRankExpressions(null);
- onnxModels = new OnnxModels(null, Optional.empty());
+ onnxModels = new FileDistributedOnnxModels(null, List.of());
rankProfiles = Map.of();
}
@@ -60,46 +55,16 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
* @param attributeFields the attribute fields to create a ranking for
*/
public RankProfileList(Schema schema,
- Map<Reference, RankProfile.Constant> constantsFromSchema,
LargeRankExpressions largeRankExpressions,
- OnnxModels onnxModels,
AttributeFields attributeFields,
DeployState deployState) {
setName(schema == null ? "default" : schema.getName());
this.largeRankExpressions = largeRankExpressions;
- this.onnxModels = onnxModels; // as ONNX models come from parsing rank expressions
this.rankProfiles = deriveRankProfiles(schema, attributeFields, deployState);
- this.constants = deriveFileDistributedConstants(schema, constantsFromSchema, rankProfiles.values(), deployState);
- }
-
- private static FileDistributedConstants deriveFileDistributedConstants(Schema schema,
- Map<Reference, RankProfile.Constant> constantsFromSchema,
- Collection<RawRankProfile> rankProfiles,
- DeployState deployState) {
- Map<Reference, RankProfile.Constant> allFileConstants = new HashMap<>();
- addFileConstants(constantsFromSchema.values(), allFileConstants, schema != null ? schema.toString() : "[global]");
- for (var profile : rankProfiles)
- addFileConstants(profile.compiled().constants().values(), allFileConstants, profile.toString());
- return new FileDistributedConstants(deployState.getFileRegistry(), allFileConstants.values());
+ this.constants = deriveFileDistributedConstants(schema, rankProfiles.values(), deployState);
+ this.onnxModels = deriveFileDistributedOnnxModels(schema, rankProfiles.values(), deployState);
}
- private static void addFileConstants(Collection<RankProfile.Constant> source,
- Map<Reference, RankProfile.Constant> destination,
- String sourceName) {
- for (var constant : source) {
- if (constant.valuePath().isEmpty()) continue;
- var existing = destination.get(constant.name());
- if ( existing != null && ! constant.equals(existing)) {
- throw new IllegalArgumentException("Duplicate constants: " + sourceName + " have " + constant +
- ", but we already have " + existing +
- ": Value reference constants must be unique across all rank profiles/models");
- }
- destination.put(constant.name(), constant);
- }
- }
-
- public FileDistributedConstants constants() { return constants; }
-
private boolean areDependenciesReady(RankProfile rank, RankProfileRegistry registry, Set<String> processedProfiles) {
return rank.inheritedNames().isEmpty() ||
processedProfiles.containsAll(rank.inheritedNames()) ||
@@ -132,7 +97,6 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
rawRankProfiles.putAll(processRankProfiles(ready,
deployState.getQueryProfiles().getRegistry(),
deployState.getImportedModels(),
- schema,
attributeFields,
deployState.getProperties(),
deployState.getExecutor()));
@@ -141,20 +105,15 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
return rawRankProfiles;
}
- private Map<String, RawRankProfile> processRankProfiles(List<RankProfile> ready,
+ private Map<String, RawRankProfile> processRankProfiles(List<RankProfile> profiles,
QueryProfileRegistry queryProfiles,
ImportedMlModels importedModels,
- Schema schema,
AttributeFields attributeFields,
ModelContext.Properties deployProperties,
ExecutorService executor) {
Map<String, Future<RawRankProfile>> futureRawRankProfiles = new LinkedHashMap<>();
- for (RankProfile rank : ready) {
- if (schema == null) {
- onnxModels.add(rank.onnxModels());
- }
-
- futureRawRankProfiles.put(rank.name(), executor.submit(() -> new RawRankProfile(rank, largeRankExpressions, queryProfiles, importedModels,
+ for (RankProfile profile : profiles) {
+ futureRawRankProfiles.put(profile.name(), executor.submit(() -> new RawRankProfile(profile, largeRankExpressions, queryProfiles, importedModels,
attributeFields, deployProperties)));
}
try {
@@ -169,19 +128,63 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
}
}
- public OnnxModels getOnnxModels() {
- return onnxModels;
+ private static FileDistributedConstants deriveFileDistributedConstants(Schema schema,
+ Collection<RawRankProfile> rankProfiles,
+ DeployState deployState) {
+ Map<Reference, RankProfile.Constant> allFileConstants = new HashMap<>();
+ addFileConstants(schema != null ? schema.constants().values() : List.of(),
+ allFileConstants,
+ schema != null ? schema.toString() : "[global]");
+ for (var profile : rankProfiles)
+ addFileConstants(profile.compiled().constants().values(), allFileConstants, profile.toString());
+ return new FileDistributedConstants(deployState.getFileRegistry(), allFileConstants.values());
}
- public Map<String, RawRankProfile> getRankProfiles() {
- return rankProfiles;
+ private static void addFileConstants(Collection<RankProfile.Constant> source,
+ Map<Reference, RankProfile.Constant> destination,
+ String sourceName) {
+ for (var constant : source) {
+ if (constant.valuePath().isEmpty()) continue;
+ var existing = destination.get(constant.name());
+ if ( existing != null && ! constant.equals(existing)) {
+ throw new IllegalArgumentException("Duplicate constants: " + sourceName + " have " + constant +
+ ", but we already have " + existing +
+ ": Value reference constants must be unique across all rank profiles/models");
+ }
+ destination.put(constant.name(), constant);
+ }
}
- /** Returns the raw rank profile with the given name, or null if it is not present */
- public RawRankProfile getRankProfile(String name) {
- return rankProfiles.get(name);
+ private static FileDistributedOnnxModels deriveFileDistributedOnnxModels(Schema schema,
+ Collection<RawRankProfile> rankProfiles,
+ DeployState deployState) {
+ Map<String, OnnxModel> allModels = new HashMap<>();
+ addOnnxModels(schema != null ? schema.onnxModels().values() : List.of(),
+ allModels,
+ schema != null ? schema.toString() : "[global]");
+ for (var profile : rankProfiles)
+ addOnnxModels(profile.compiled().onnxModels().values(), allModels, profile.toString());
+ return new FileDistributedOnnxModels(deployState.getFileRegistry(), allModels.values());
}
+ private static void addOnnxModels(Collection<OnnxModel> source,
+ Map<String, OnnxModel> destination,
+ String sourceName) {
+ for (var model : source) {
+ var existing = destination.get(model.getName());
+ if ( existing != null && ! model.equals(existing)) {
+ throw new IllegalArgumentException("Duplicate onnx model: " + sourceName + " have " + model +
+ ", but we already have " + existing +
+ ": Onnx models must be unique across all rank profiles/models");
+ }
+ destination.put(model.getName(), model);
+ }
+ }
+
+ public Map<String, RawRankProfile> getRankProfiles() { return rankProfiles; }
+ public FileDistributedConstants constants() { return constants; }
+ public FileDistributedOnnxModels getOnnxModels() { return onnxModels; }
+
@Override
public String getDerivedName() { return "rank-profiles"; }
@@ -201,25 +204,7 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
}
public void getConfig(OnnxModelsConfig.Builder builder) {
- for (OnnxModel model : onnxModels.asMap().values()) {
- if ("".equals(model.getFileReference()))
- log.warning("Illegal file reference " + model); // Let tests pass ... we should find a better way
- else {
- OnnxModelsConfig.Model.Builder modelBuilder = new OnnxModelsConfig.Model.Builder();
- modelBuilder.dry_run_on_setup(true);
- modelBuilder.name(model.getName());
- modelBuilder.fileref(model.getFileReference());
- model.getInputMap().forEach((name, source) -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source)));
- model.getOutputMap().forEach((name, as) -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(name).as(as)));
- if (model.getStatelessExecutionMode().isPresent())
- modelBuilder.stateless_execution_mode(model.getStatelessExecutionMode().get());
- if (model.getStatelessInterOpThreads().isPresent())
- modelBuilder.stateless_interop_threads(model.getStatelessInterOpThreads().get());
- if (model.getStatelessIntraOpThreads().isPresent())
- modelBuilder.stateless_intraop_threads(model.getStatelessIntraOpThreads().get());
-
- builder.model(modelBuilder);
- }
- }
+ onnxModels.getConfig(builder);
}
+
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
index 2189d9d97db..dcd5019cf58 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
@@ -483,7 +483,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
private void deriveOnnxModelFunctionsAndFeatures(RankProfile rankProfile) {
if (rankProfile.schema() == null) return;
- if (rankProfile.schema().onnxModels().asMap().isEmpty()) return;
+ if (rankProfile.onnxModels().isEmpty()) return;
replaceOnnxFunctionInputs(rankProfile);
replaceImplicitOnnxConfigFeatures(summaryFeatures, rankProfile);
replaceImplicitOnnxConfigFeatures(matchFeatures, rankProfile);
@@ -492,7 +492,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
private void replaceOnnxFunctionInputs(RankProfile rankProfile) {
Set<String> functionNames = rankProfile.getFunctions().keySet();
if (functionNames.isEmpty()) return;
- for (OnnxModel onnxModel: rankProfile.schema().onnxModels().asMap().values()) {
+ for (OnnxModel onnxModel: rankProfile.onnxModels().values()) {
for (Map.Entry<String, String> mapping : onnxModel.getInputMap().entrySet()) {
String source = mapping.getValue();
if (functionNames.contains(source)) {
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedSchemas.java b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedSchemas.java
index e56245d1332..4c32f11c20d 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedSchemas.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedSchemas.java
@@ -207,12 +207,10 @@ public class ConvertParsedSchemas {
if (documentsOnly) {
return; // skip ranking-only content, not used for document type generation
}
- for (var constant : parsed.getConstants()) {
+ for (var constant : parsed.getConstants())
schema.add(constant);
- }
- for (var onnxModel : parsed.getOnnxModels()) {
- schema.onnxModels().add(onnxModel);
- }
+ for (var onnxModel : parsed.getOnnxModels())
+ schema.add(onnxModel);
rankProfileRegistry.add(new DefaultRankProfile(schema, rankProfileRegistry));
rankProfileRegistry.add(new UnrankedRankProfile(schema, rankProfileRegistry));
var rankConverter = new ConvertParsedRanking(rankProfileRegistry);
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java
index fdbde08d926..19fbc116558 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java
@@ -86,10 +86,8 @@ public class OnnxModelConfigGenerator extends Processor {
}
OnnxModel onnxModel = schema.onnxModels().get(modelConfigName);
- if (onnxModel == null) {
- onnxModel = new OnnxModel(modelConfigName, path);
- schema.onnxModels().add(onnxModel);
- }
+ if (onnxModel == null)
+ schema.add(new OnnxModel(modelConfigName, path));
}
}
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java
index 4153cca4b5b..83e0c367292 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java
@@ -4,6 +4,7 @@ package com.yahoo.searchdefinition.processing;
import com.yahoo.config.application.api.DeployLogger;
import com.yahoo.searchdefinition.OnnxModel;
+import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.RankProfileRegistry;
import com.yahoo.searchdefinition.Schema;
import com.yahoo.vespa.model.container.search.QueryProfiles;
@@ -28,9 +29,11 @@ public class OnnxModelTypeResolver extends Processor {
@Override
public void process(boolean validate, boolean documentsOnly) {
if (documentsOnly) return;
- for (OnnxModel onnxModel : schema.onnxModels().asMap().values()) {
- OnnxModelInfo onnxModelInfo = OnnxModelInfo.load(onnxModel.getFileName(), schema.applicationPackage());
- onnxModel.setModelInfo(onnxModelInfo);
+ for (OnnxModel onnxModel : schema.declaredOnnxModels().values())
+ onnxModel.setModelInfo(OnnxModelInfo.load(onnxModel.getFileName(), schema.applicationPackage()));
+ for (RankProfile profile : rankProfileRegistry.rankProfilesOf(schema)) {
+ for (OnnxModel onnxModel : profile.declaredOnnxModels().values())
+ onnxModel.setModelInfo(OnnxModelInfo.load(onnxModel.getFileName(), schema.applicationPackage()));
}
}