summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLester Solbakken <lesters@users.noreply.github.com>2022-05-16 16:27:04 +0200
committerGitHub <noreply@github.com>2022-05-16 16:27:04 +0200
commitcc8d97a8e9f56d139700424e1393ea954723ae2d (patch)
tree9fbf87884087901ba6a586bebad3aa42da822a4c
parent1e1c46f0d579de70447f64a20a148447ccefe862 (diff)
parenta37ed1c28091f234f25c9b3649999821eb7f4802 (diff)
Merge pull request #22620 from vespa-engine/bratseth/models-in-profiles
Bratseth/models in profiles
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/ImmutableSchema.java3
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java67
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java50
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/Schema.java29
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java3
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/FileDistributedConstants.java11
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/FileDistributedOnnxModels.java63
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java149
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java4
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java9
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedRanking.java3
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedSchemas.java8
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedRankProfile.java7
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedSchema.java2
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java26
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java9
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java32
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantValidator.java2
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java2
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java2
-rw-r--r--config-model/src/main/javacc/IntermediateParser.jj40
-rw-r--r--config-model/src/test/integration/onnx-model/schemas/test.sd15
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/SchemaTestCase.java10
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/derived/GeminiTestCase.java2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java28
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/VespaMlModelTestCase.java2
26 files changed, 303 insertions, 275 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/ImmutableSchema.java b/config-model/src/main/java/com/yahoo/searchdefinition/ImmutableSchema.java
index 1c643292a05..3719313179f 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/ImmutableSchema.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/ImmutableSchema.java
@@ -4,6 +4,7 @@ package com.yahoo.searchdefinition;
import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.config.application.api.DeployLogger;
import com.yahoo.config.model.api.ModelContext;
+import com.yahoo.searchdefinition.derived.FileDistributedOnnxModels;
import com.yahoo.searchdefinition.document.ImmutableSDField;
import com.yahoo.searchdefinition.document.SDDocumentType;
import com.yahoo.searchdefinition.document.SDField;
@@ -37,7 +38,7 @@ public interface ImmutableSchema {
ModelContext.Properties getDeployProperties();
Map<Reference, RankProfile.Constant> constants();
LargeRankExpressions rankExpressionFiles();
- OnnxModels onnxModels();
+ Map<String, OnnxModel> onnxModels();
Stream<ImmutableSDField> allImportedFields();
SDDocumentType getDocument();
ImmutableSDField getField(String name);
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java
deleted file mode 100644
index c9c12100552..00000000000
--- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java
+++ /dev/null
@@ -1,67 +0,0 @@
-// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchdefinition;
-
-import com.yahoo.config.application.api.FileRegistry;
-
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Optional;
-
-/**
- * ONNX models tied to a search definition or global.
- *
- * @author lesters
- */
-public class OnnxModels {
-
- private final FileRegistry fileRegistry;
-
- /** The schema this belongs to, or empty if it is global */
- private final Optional<Schema> owner;
-
- private final Map<String, OnnxModel> models = new HashMap<>();
-
- public OnnxModels(FileRegistry fileRegistry, Optional<Schema> owner) {
- this.fileRegistry = fileRegistry;
- this.owner = owner;
- }
-
- public void add(OnnxModel model) {
- model.validate();
- model.register(fileRegistry);
- String name = model.getName();
- models.put(name, model);
- }
-
- public void add(Map<String, OnnxModel> models) {
- models.values().forEach(this::add);
- }
-
- public OnnxModel get(String name) {
- var model = models.get(name);
- if (model != null) return model;
- if (owner.isPresent() && owner.get().inherited().isPresent())
- return owner.get().inherited().get().onnxModels().get(name);
- return null;
- }
-
- public boolean has(String name) {
- boolean has = models.containsKey(name);
- if (has) return true;
- if (owner.isPresent() && owner.get().inherited().isPresent())
- return owner.get().inherited().get().onnxModels().has(name);
- return false;
- }
-
- public Map<String, OnnxModel> asMap() {
- // Shortcuts
- if (owner.isEmpty() || owner.get().inherited().isEmpty()) return Collections.unmodifiableMap(models);
- if (models.isEmpty()) return owner.get().inherited().get().onnxModels().asMap();
-
- var allModels = new HashMap<>(owner.get().inherited().get().onnxModels().asMap());
- allModels.putAll(models);
- return Collections.unmodifiableMap(allModels);
- }
-
-}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
index fd7e7e57f00..ec560484513 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
@@ -116,7 +116,9 @@ public class RankProfile implements Cloneable {
private Map<Reference, Input> inputs = new LinkedHashMap<>();
- private Map<Reference, Constant> constants = new HashMap<>();
+ private Map<Reference, Constant> constants = new LinkedHashMap<>();
+
+ private Map<String, OnnxModel> onnxModels = new LinkedHashMap<>();
private Set<String> filterFields = new HashSet<>();
@@ -128,8 +130,6 @@ public class RankProfile implements Cloneable {
private Boolean strict;
- /** Global onnx models not tied to a schema */
- private final OnnxModels onnxModels;
private final ApplicationPackage applicationPackage;
private final DeployLogger deployLogger;
@@ -144,7 +144,6 @@ public class RankProfile implements Cloneable {
public RankProfile(String name, Schema schema, RankProfileRegistry rankProfileRegistry) {
this.name = Objects.requireNonNull(name, "name cannot be null");
this.schema = Objects.requireNonNull(schema, "schema cannot be null");
- this.onnxModels = null;
this.rankProfileRegistry = rankProfileRegistry;
this.applicationPackage = schema.applicationPackage();
this.deployLogger = schema.getDeployLogger();
@@ -156,11 +155,10 @@ public class RankProfile implements Cloneable {
* @param name the name of the new profile
*/
public RankProfile(String name, ApplicationPackage applicationPackage, DeployLogger deployLogger,
- RankProfileRegistry rankProfileRegistry, OnnxModels onnxModels) {
+ RankProfileRegistry rankProfileRegistry) {
this.name = Objects.requireNonNull(name, "name cannot be null");
this.schema = null;
this.rankProfileRegistry = rankProfileRegistry;
- this.onnxModels = onnxModels;
this.applicationPackage = applicationPackage;
this.deployLogger = deployLogger;
}
@@ -175,10 +173,6 @@ public class RankProfile implements Cloneable {
return applicationPackage;
}
- public Map<String, OnnxModel> onnxModels() {
- return schema != null ? schema.onnxModels().asMap() : onnxModels.asMap();
- }
-
private Stream<ImmutableSDField> allFields() {
if (schema == null) return Stream.empty();
if (allFieldsList == null) {
@@ -411,6 +405,9 @@ public class RankProfile implements Cloneable {
constants.put(constant.name(), constant);
}
+ /** Returns an unmodifiable view of the constants declared in this */
+ public Map<Reference, Constant> declaredConstants() { return Collections.unmodifiableMap(constants); }
+
/** Returns an unmodifiable view of the constants available in this */
public Map<Reference, Constant> constants() {
Map<Reference, Constant> allConstants = new HashMap<>();
@@ -430,8 +427,31 @@ public class RankProfile implements Cloneable {
return allConstants;
}
- /** Returns an unmodifiable view of the constants declared in this */
- public Map<Reference, Constant> declaredConstants() { return Collections.unmodifiableMap(constants); }
+ public void add(OnnxModel model) {
+ onnxModels.put(model.getName(), model);
+ }
+
+ /** Returns an unmodifiable map of the onnx models declared in this. */
+ public Map<String, OnnxModel> declaredOnnxModels() { return onnxModels; }
+
+ /** Returns an unmodifiable map of the onnx models available in this. */
+ public Map<String, OnnxModel> onnxModels() {
+ Map<String, OnnxModel> allModels = new HashMap<>();
+ for (var inheritedProfile : inherited()) {
+ for (var model : inheritedProfile.onnxModels().values()) {
+ if (allModels.containsKey(model.getName()))
+ throw new IllegalArgumentException(model + "' is present in " +
+ inheritedProfile + " inherited by " +
+ this + ", but is also present in another profile inherited by it");
+ allModels.put(model.getName(), model);
+ }
+ }
+
+ if (schema != null)
+ allModels.putAll(schema.onnxModels());
+ allModels.putAll(onnxModels);
+ return allModels;
+ }
public void addAttributeType(String attributeName, String attributeType) {
attributeTypes.addType(attributeName, attributeType);
@@ -1064,10 +1084,8 @@ public class RankProfile implements Cloneable {
}
// Add output types for ONNX models
- for (Map.Entry<String, OnnxModel> entry : onnxModels().entrySet()) {
- String modelName = entry.getKey();
- OnnxModel model = entry.getValue();
- Arguments args = new Arguments(new ReferenceNode(modelName));
+ for (var model : onnxModels().values()) {
+ Arguments args = new Arguments(new ReferenceNode(model.getName()));
Map<String, TensorType> inputTypes = resolveOnnxInputTypes(model, context);
TensorType defaultOutputType = model.getTensorType(model.getDefaultOutput(), inputTypes);
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/Schema.java b/config-model/src/main/java/com/yahoo/searchdefinition/Schema.java
index db105edc9d4..8139d46cc0a 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/Schema.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/Schema.java
@@ -9,6 +9,7 @@ import com.yahoo.config.model.application.provider.BaseDeployLogger;
import com.yahoo.config.model.deploy.TestProperties;
import com.yahoo.document.DataTypeName;
import com.yahoo.document.Field;
+import com.yahoo.searchdefinition.derived.FileDistributedOnnxModels;
import com.yahoo.searchdefinition.derived.SummaryClass;
import com.yahoo.searchdefinition.document.Attribute;
import com.yahoo.searchdefinition.document.ImmutableSDField;
@@ -27,6 +28,7 @@ import java.io.Reader;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
+import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
@@ -90,7 +92,8 @@ public class Schema implements ImmutableSchema {
// TODO: Remove on Vespa 9: Should always be in a rank profile
private final Map<Reference, RankProfile.Constant> constants = new LinkedHashMap<>();
- private final OnnxModels onnxModels;
+ // TODO: Remove on Vespa 9: Should always be in a rank profile
+ private final Map<String, OnnxModel> onnxModels = new LinkedHashMap<>();
/** All imported fields of this (and parent schemas) */
// TODO: Use empty, not optional
@@ -150,7 +153,6 @@ public class Schema implements ImmutableSchema {
this.properties = properties;
this.documentsOnly = documentsOnly;
largeRankExpressions = new LargeRankExpressions(fileRegistry);
- onnxModels = new OnnxModels(fileRegistry, Optional.of(this));
}
/**
@@ -218,7 +220,7 @@ public class Schema implements ImmutableSchema {
*/
public void addDocument(SDDocumentType document) {
if (documentType != null) {
- throw new IllegalArgumentException("Searchdefinition cannot have more than one document");
+ throw new IllegalArgumentException("Schema cannot have more than one document");
}
documentType = document;
}
@@ -230,6 +232,10 @@ public class Schema implements ImmutableSchema {
constants.put(constant.name(), constant);
}
+ /** Returns an unmodifiable map of the constants declared in this. */
+ public Map<Reference, RankProfile.Constant> declaredConstants() { return constants; }
+
+ /** Returns an unmodifiable map of the constants available in this. */
@Override
public Map<Reference, RankProfile.Constant> constants() {
if (inherited().isEmpty()) return Collections.unmodifiableMap(constants);
@@ -240,8 +246,23 @@ public class Schema implements ImmutableSchema {
return allConstants;
}
+ public void add(OnnxModel model) {
+ onnxModels.put(model.getName(), model);
+ }
+
+ /** Returns an unmodifiable map of the onnx models declared in this. */
+ public Map<String, OnnxModel> declaredOnnxModels() { return onnxModels; }
+
+ /** Returns an unmodifiable map of the onnx models available in this. */
@Override
- public OnnxModels onnxModels() { return onnxModels; }
+ public Map<String, OnnxModel> onnxModels() {
+ if (inherited().isEmpty()) return Collections.unmodifiableMap(onnxModels);
+ if (onnxModels.isEmpty()) return inherited().get().onnxModels();
+
+ Map<String, OnnxModel> allModels = new LinkedHashMap<>(inherited().get().onnxModels());
+ allModels.putAll(onnxModels);
+ return allModels;
+ }
public Optional<TemporaryImportedFields> temporaryImportedFields() {
return temporaryImportedFields;
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java
index 50029613b2e..c7b9b94a4b2 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java
@@ -82,8 +82,7 @@ public class DerivedConfiguration implements AttributesConfig.Producer {
summaries = new Summaries(schema, deployState.getDeployLogger(), deployState.getProperties().featureFlags());
summaryMap = new SummaryMap(schema);
juniperrc = new Juniperrc(schema);
- rankProfileList = new RankProfileList(schema, schema.constants(), schema.rankExpressionFiles(),
- schema.onnxModels(), attributeFields, deployState);
+ rankProfileList = new RankProfileList(schema, schema.rankExpressionFiles(), attributeFields, deployState);
indexingScript = new IndexingScript(schema);
indexInfo = new IndexInfo(schema);
schemaInfo = new SchemaInfo(schema, deployState.rankProfileRegistry(), summaries, summaryMap);
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/FileDistributedConstants.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/FileDistributedConstants.java
index 8de86beacdb..433bfb108d6 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/FileDistributedConstants.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/FileDistributedConstants.java
@@ -5,12 +5,12 @@ import com.yahoo.config.application.api.FileRegistry;
import com.yahoo.searchdefinition.DistributableResource;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.tensor.TensorType;
+import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
-import java.util.function.Function;
/**
* Constant values for ranking/model execution tied to a rank profile,
@@ -41,6 +41,15 @@ public class FileDistributedConstants {
/** Returns a read-only map of the constants in this indexed by name. */
public Map<String, DistributableConstant> asMap() { return constants; }
+ public void getConfig(RankingConstantsConfig.Builder builder) {
+ for (var constant : constants.values()) {
+ builder.constant(new RankingConstantsConfig.Constant.Builder()
+ .name(constant.getName())
+ .fileref(constant.getFileReference())
+ .type(constant.getType()));
+ }
+ }
+
public static class DistributableConstant extends DistributableResource {
private final TensorType tensorType;
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/FileDistributedOnnxModels.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/FileDistributedOnnxModels.java
new file mode 100644
index 00000000000..a1310d6c18e
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/FileDistributedOnnxModels.java
@@ -0,0 +1,63 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchdefinition.derived;
+
+import com.yahoo.config.application.api.FileRegistry;
+import com.yahoo.searchdefinition.OnnxModel;
+import com.yahoo.searchdefinition.Schema;
+import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.Optional;
+import java.util.logging.Logger;
+
+/**
+ * ONNX models distributed as files.
+ *
+ * @author bratseth
+ */
+public class FileDistributedOnnxModels {
+
+ private static final Logger log = Logger.getLogger(FileDistributedOnnxModels.class.getName());
+
+ private final Map<String, OnnxModel> models;
+
+ public FileDistributedOnnxModels(FileRegistry fileRegistry, Collection<OnnxModel> models) {
+ Map<String, OnnxModel> distributableModels = new LinkedHashMap<>();
+ for (var model : models) {
+ model.validate();
+ model.register(fileRegistry);
+ distributableModels.put(model.getName(), model);
+ }
+ this.models = Collections.unmodifiableMap(distributableModels);
+ }
+
+ public Map<String, OnnxModel> asMap() { return models; }
+
+ public void getConfig(OnnxModelsConfig.Builder builder) {
+ for (OnnxModel model : models.values()) {
+ if ("".equals(model.getFileReference()))
+ log.warning("Illegal file reference " + model); // Let tests pass ... we should find a better way
+ else {
+ OnnxModelsConfig.Model.Builder modelBuilder = new OnnxModelsConfig.Model.Builder();
+ modelBuilder.dry_run_on_setup(true);
+ modelBuilder.name(model.getName());
+ modelBuilder.fileref(model.getFileReference());
+ model.getInputMap().forEach((name, source) -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source)));
+ model.getOutputMap().forEach((name, as) -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(name).as(as)));
+ if (model.getStatelessExecutionMode().isPresent())
+ modelBuilder.stateless_execution_mode(model.getStatelessExecutionMode().get());
+ if (model.getStatelessInterOpThreads().isPresent())
+ modelBuilder.stateless_interop_threads(model.getStatelessInterOpThreads().get());
+ if (model.getStatelessIntraOpThreads().isPresent())
+ modelBuilder.stateless_intraop_threads(model.getStatelessIntraOpThreads().get());
+
+ builder.model(modelBuilder);
+ }
+ }
+ }
+
+}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
index d415f7d0d6e..081450275d1 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
@@ -2,13 +2,11 @@
package com.yahoo.searchdefinition.derived;
import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels;
-import com.yahoo.config.application.api.DeployLogger;
import com.yahoo.config.model.api.ModelContext;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.search.query.profile.QueryProfileRegistry;
-import com.yahoo.searchdefinition.OnnxModel;
-import com.yahoo.searchdefinition.OnnxModels;
import com.yahoo.searchdefinition.LargeRankExpressions;
+import com.yahoo.searchdefinition.OnnxModel;
import com.yahoo.searchdefinition.RankProfileRegistry;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.vespa.config.search.RankProfilesConfig;
@@ -24,13 +22,10 @@ import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
-import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
-import java.util.logging.Level;
-import java.util.logging.Logger;
/**
* The derived rank profiles of a schema
@@ -39,19 +34,17 @@ import java.util.logging.Logger;
*/
public class RankProfileList extends Derived implements RankProfilesConfig.Producer {
- private static final Logger log = Logger.getLogger(RankProfileList.class.getName());
-
private final Map<String, RawRankProfile> rankProfiles;
private final FileDistributedConstants constants;
private final LargeRankExpressions largeRankExpressions;
- private final OnnxModels onnxModels;
+ private final FileDistributedOnnxModels onnxModels;
public static final RankProfileList empty = new RankProfileList();
private RankProfileList() {
constants = new FileDistributedConstants(null, List.of());
largeRankExpressions = new LargeRankExpressions(null);
- onnxModels = new OnnxModels(null, Optional.empty());
+ onnxModels = new FileDistributedOnnxModels(null, List.of());
rankProfiles = Map.of();
}
@@ -62,46 +55,16 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
* @param attributeFields the attribute fields to create a ranking for
*/
public RankProfileList(Schema schema,
- Map<Reference, RankProfile.Constant> constantsFromSchema,
LargeRankExpressions largeRankExpressions,
- OnnxModels onnxModels,
AttributeFields attributeFields,
DeployState deployState) {
setName(schema == null ? "default" : schema.getName());
this.largeRankExpressions = largeRankExpressions;
- this.onnxModels = onnxModels; // as ONNX models come from parsing rank expressions
this.rankProfiles = deriveRankProfiles(schema, attributeFields, deployState);
- this.constants = deriveFileDistributedConstants(schema, constantsFromSchema, rankProfiles.values(), deployState);
- }
-
- private static FileDistributedConstants deriveFileDistributedConstants(Schema schema,
- Map<Reference, RankProfile.Constant> constantsFromSchema,
- Collection<RawRankProfile> rankProfiles,
- DeployState deployState) {
- Map<Reference, RankProfile.Constant> allFileConstants = new HashMap<>();
- addFileConstants(constantsFromSchema.values(), allFileConstants, schema != null ? schema.toString() : "[global]");
- for (var profile : rankProfiles)
- addFileConstants(profile.compiled().constants().values(), allFileConstants, profile.toString());
- return new FileDistributedConstants(deployState.getFileRegistry(), allFileConstants.values());
- }
-
- private static void addFileConstants(Collection<RankProfile.Constant> source,
- Map<Reference, RankProfile.Constant> destination,
- String sourceName) {
- for (var constant : source) {
- if (constant.valuePath().isEmpty()) continue;
- var existing = destination.get(constant.name());
- if ( existing != null && ! constant.equals(existing)) {
- throw new IllegalArgumentException("Duplicate constants: " + sourceName + " have " + constant +
- ", but we already have " + existing +
- ": Value reference constants must be unique across all rank profiles/models");
- }
- destination.put(constant.name(), constant);
- }
+ this.constants = deriveFileDistributedConstants(schema, rankProfiles.values(), deployState);
+ this.onnxModels = deriveFileDistributedOnnxModels(schema, rankProfiles.values(), deployState);
}
- public FileDistributedConstants constants() { return constants; }
-
private boolean areDependenciesReady(RankProfile rank, RankProfileRegistry registry, Set<String> processedProfiles) {
return rank.inheritedNames().isEmpty() ||
processedProfiles.containsAll(rank.inheritedNames()) ||
@@ -134,7 +97,6 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
rawRankProfiles.putAll(processRankProfiles(ready,
deployState.getQueryProfiles().getRegistry(),
deployState.getImportedModels(),
- schema,
attributeFields,
deployState.getProperties(),
deployState.getExecutor()));
@@ -143,20 +105,15 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
return rawRankProfiles;
}
- private Map<String, RawRankProfile> processRankProfiles(List<RankProfile> ready,
+ private Map<String, RawRankProfile> processRankProfiles(List<RankProfile> profiles,
QueryProfileRegistry queryProfiles,
ImportedMlModels importedModels,
- Schema schema,
AttributeFields attributeFields,
ModelContext.Properties deployProperties,
ExecutorService executor) {
Map<String, Future<RawRankProfile>> futureRawRankProfiles = new LinkedHashMap<>();
- for (RankProfile rank : ready) {
- if (schema == null) {
- onnxModels.add(rank.onnxModels());
- }
-
- futureRawRankProfiles.put(rank.name(), executor.submit(() -> new RawRankProfile(rank, largeRankExpressions, queryProfiles, importedModels,
+ for (RankProfile profile : profiles) {
+ futureRawRankProfiles.put(profile.name(), executor.submit(() -> new RawRankProfile(profile, largeRankExpressions, queryProfiles, importedModels,
attributeFields, deployProperties)));
}
try {
@@ -171,19 +128,63 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
}
}
- public OnnxModels getOnnxModels() {
- return onnxModels;
+ private static FileDistributedConstants deriveFileDistributedConstants(Schema schema,
+ Collection<RawRankProfile> rankProfiles,
+ DeployState deployState) {
+ Map<Reference, RankProfile.Constant> allFileConstants = new HashMap<>();
+ addFileConstants(schema != null ? schema.constants().values() : List.of(),
+ allFileConstants,
+ schema != null ? schema.toString() : "[global]");
+ for (var profile : rankProfiles)
+ addFileConstants(profile.compiled().constants().values(), allFileConstants, profile.toString());
+ return new FileDistributedConstants(deployState.getFileRegistry(), allFileConstants.values());
}
- public Map<String, RawRankProfile> getRankProfiles() {
- return rankProfiles;
+ private static void addFileConstants(Collection<RankProfile.Constant> source,
+ Map<Reference, RankProfile.Constant> destination,
+ String sourceName) {
+ for (var constant : source) {
+ if (constant.valuePath().isEmpty()) continue;
+ var existing = destination.get(constant.name());
+ if ( existing != null && ! constant.equals(existing)) {
+ throw new IllegalArgumentException("Duplicate constants: " + sourceName + " have " + constant +
+ ", but we already have " + existing +
+ ": Value reference constants must be unique across all rank profiles/models");
+ }
+ destination.put(constant.name(), constant);
+ }
}
- /** Returns the raw rank profile with the given name, or null if it is not present */
- public RawRankProfile getRankProfile(String name) {
- return rankProfiles.get(name);
+ private static FileDistributedOnnxModels deriveFileDistributedOnnxModels(Schema schema,
+ Collection<RawRankProfile> rankProfiles,
+ DeployState deployState) {
+ Map<String, OnnxModel> allModels = new LinkedHashMap<>();
+ addOnnxModels(schema != null ? schema.onnxModels().values() : List.of(),
+ allModels,
+ schema != null ? schema.toString() : "[global]");
+ for (var profile : rankProfiles)
+ addOnnxModels(profile.compiled().onnxModels().values(), allModels, profile.toString());
+ return new FileDistributedOnnxModels(deployState.getFileRegistry(), allModels.values());
}
+ private static void addOnnxModels(Collection<OnnxModel> source,
+ Map<String, OnnxModel> destination,
+ String sourceName) {
+ for (var model : source) {
+ var existing = destination.get(model.getName());
+ if ( existing != null && ! model.equals(existing)) {
+ throw new IllegalArgumentException("Duplicate onnx model: " + sourceName + " have " + model +
+ ", but we already have " + existing +
+ ": Onnx models must be unique across all rank profiles/models");
+ }
+ destination.put(model.getName(), model);
+ }
+ }
+
+ public Map<String, RawRankProfile> getRankProfiles() { return rankProfiles; }
+ public FileDistributedConstants constants() { return constants; }
+ public FileDistributedOnnxModels getOnnxModels() { return onnxModels; }
+
@Override
public String getDerivedName() { return "rank-profiles"; }
@@ -199,37 +200,11 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
}
public void getConfig(RankingConstantsConfig.Builder builder) {
- for (var constant : constants.asMap().values()) {
- if ("".equals(constant.getFileReference()))
- log.warning("Illegal file reference " + constant); // Let tests pass ... we should find a better way
- else
- builder.constant(new RankingConstantsConfig.Constant.Builder()
- .name(constant.getName())
- .fileref(constant.getFileReference())
- .type(constant.getType()));
- }
+ constants.getConfig(builder);
}
public void getConfig(OnnxModelsConfig.Builder builder) {
- for (OnnxModel model : onnxModels.asMap().values()) {
- if ("".equals(model.getFileReference()))
- log.warning("Illegal file reference " + model); // Let tests pass ... we should find a better way
- else {
- OnnxModelsConfig.Model.Builder modelBuilder = new OnnxModelsConfig.Model.Builder();
- modelBuilder.dry_run_on_setup(true);
- modelBuilder.name(model.getName());
- modelBuilder.fileref(model.getFileReference());
- model.getInputMap().forEach((name, source) -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source)));
- model.getOutputMap().forEach((name, as) -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(name).as(as)));
- if (model.getStatelessExecutionMode().isPresent())
- modelBuilder.stateless_execution_mode(model.getStatelessExecutionMode().get());
- if (model.getStatelessInterOpThreads().isPresent())
- modelBuilder.stateless_interop_threads(model.getStatelessInterOpThreads().get());
- if (model.getStatelessIntraOpThreads().isPresent())
- modelBuilder.stateless_intraop_threads(model.getStatelessIntraOpThreads().get());
-
- builder.model(modelBuilder);
- }
- }
+ onnxModels.getConfig(builder);
}
+
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
index 2189d9d97db..dcd5019cf58 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
@@ -483,7 +483,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
private void deriveOnnxModelFunctionsAndFeatures(RankProfile rankProfile) {
if (rankProfile.schema() == null) return;
- if (rankProfile.schema().onnxModels().asMap().isEmpty()) return;
+ if (rankProfile.onnxModels().isEmpty()) return;
replaceOnnxFunctionInputs(rankProfile);
replaceImplicitOnnxConfigFeatures(summaryFeatures, rankProfile);
replaceImplicitOnnxConfigFeatures(matchFeatures, rankProfile);
@@ -492,7 +492,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
private void replaceOnnxFunctionInputs(RankProfile rankProfile) {
Set<String> functionNames = rankProfile.getFunctions().keySet();
if (functionNames.isEmpty()) return;
- for (OnnxModel onnxModel: rankProfile.schema().onnxModels().asMap().values()) {
+ for (OnnxModel onnxModel: rankProfile.onnxModels().values()) {
for (Map.Entry<String, String> mapping : onnxModel.getInputMap().entrySet()) {
String source = mapping.getValue();
if (functionNames.contains(source)) {
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
index 71493df357c..58a9c78254a 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
@@ -53,9 +53,8 @@ public class OnnxModelTransformer extends ExpressionTransformer<RankProfileTrans
return transformFeature(feature, context.rankProfile());
}
- public static ExpressionNode transformFeature(ReferenceNode feature, RankProfile rankProfile) {
- ImmutableSchema search = rankProfile.schema();
- final String featureName = feature.getName();
+ public static ExpressionNode transformFeature(ReferenceNode feature, RankProfile profile) {
+ String featureName = feature.getName();
if ( ! featureName.equals("onnxModel") && ! featureName.equals("onnx")) return feature;
Arguments arguments = feature.getArguments();
@@ -71,11 +70,11 @@ public class OnnxModelTransformer extends ExpressionTransformer<RankProfileTrans
// ONNX file that was transformed to Vespa ranking expressions. We then assume it is in the model store.
String modelConfigName = getModelConfigName(feature.reference());
- OnnxModel onnxModel = search.onnxModels().get(modelConfigName);
+ OnnxModel onnxModel = profile.onnxModels().get(modelConfigName);
if (onnxModel == null) {
String path = asString(arguments.expressions().get(0));
ModelName modelName = new ModelName(null, Path.fromString(path), true);
- ConvertedModel convertedModel = ConvertedModel.fromStore(search.applicationPackage(), modelName, path, rankProfile);
+ ConvertedModel convertedModel = ConvertedModel.fromStore(profile.schema().applicationPackage(), modelName, path, profile);
FeatureArguments featureArguments = new FeatureArguments(arguments);
return convertedModel.expression(featureArguments, null);
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedRanking.java b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedRanking.java
index f772c5fe903..47d770f609e 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedRanking.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedRanking.java
@@ -43,6 +43,9 @@ public class ConvertParsedRanking {
for (var constant : parsed.getConstants().values())
profile.add(constant);
+ for (var onnxModel : parsed.getOnnxModels())
+ profile.add(onnxModel);
+
for (var input : parsed.getInputs().entrySet())
profile.addInput(input.getKey(), input.getValue());
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedSchemas.java b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedSchemas.java
index e56245d1332..4c32f11c20d 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedSchemas.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ConvertParsedSchemas.java
@@ -207,12 +207,10 @@ public class ConvertParsedSchemas {
if (documentsOnly) {
return; // skip ranking-only content, not used for document type generation
}
- for (var constant : parsed.getConstants()) {
+ for (var constant : parsed.getConstants())
schema.add(constant);
- }
- for (var onnxModel : parsed.getOnnxModels()) {
- schema.onnxModels().add(onnxModel);
- }
+ for (var onnxModel : parsed.getOnnxModels())
+ schema.add(onnxModel);
rankProfileRegistry.add(new DefaultRankProfile(schema, rankProfileRegistry));
rankProfileRegistry.add(new UnrankedRankProfile(schema, rankProfileRegistry));
var rankConverter = new ConvertParsedRanking(rankProfileRegistry);
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedRankProfile.java
index 0ade3bfd76b..8f0f92c4027 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedRankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedRankProfile.java
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition.parser;
+import com.yahoo.searchdefinition.OnnxModel;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.RankProfile.MatchPhaseSettings;
import com.yahoo.searchdefinition.RankProfile.MutateOperation;
@@ -54,6 +55,7 @@ class ParsedRankProfile extends ParsedBlock {
private final Map<String, List<String>> rankProperties = new LinkedHashMap<>();
private final Map<Reference, RankProfile.Constant> constants = new LinkedHashMap<>();
private final Map<Reference, RankProfile.Input> inputs = new LinkedHashMap<>();
+ private final List<OnnxModel> onnxModels = new ArrayList<>();
ParsedRankProfile(String name) {
super(name, "rank-profile");
@@ -85,6 +87,7 @@ class ParsedRankProfile extends ParsedBlock {
Map<String, List<String>> getRankProperties() { return Collections.unmodifiableMap(rankProperties); }
Map<Reference, RankProfile.Constant> getConstants() { return Collections.unmodifiableMap(constants); }
Map<Reference, RankProfile.Input> getInputs() { return Collections.unmodifiableMap(inputs); }
+ List<OnnxModel> getOnnxModels() { return List.copyOf(onnxModels); }
Optional<String> getInheritedSummaryFeatures() { return Optional.ofNullable(this.inheritedSummaryFeatures); }
Optional<String> getSecondPhaseExpression() { return Optional.ofNullable(this.secondPhaseExpression); }
@@ -111,6 +114,10 @@ class ParsedRankProfile extends ParsedBlock {
inputs.put(name, input);
}
+ void add(OnnxModel model) {
+ onnxModels.add(model);
+ }
+
void addFieldRankFilter(String field, boolean filter) {
fieldsRankFilter.put(field, filter);
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedSchema.java b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedSchema.java
index 2bc10554b25..4c102594479 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedSchema.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/parser/ParsedSchema.java
@@ -123,7 +123,7 @@ public class ParsedSchema extends ParsedBlock {
extraIndexes.put(idxName, index);
}
- void addOnnxModel(OnnxModel model) {
+ void add(OnnxModel model) {
onnxModels.add(model);
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java
index fdbde08d926..70ce051bb21 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java
@@ -23,7 +23,7 @@ import java.util.Map;
*
* onnx("files/model.onnx", "path/to/output:1")
*
- * And generates an "onnx-model" configuration as if it was defined in the schema:
+ * And generates an "onnx-model" configuration as if it was defined in the profile:
*
* onnx-model files_model_onnx {
* file: "files/model.onnx"
@@ -45,31 +45,31 @@ public class OnnxModelConfigGenerator extends Processor {
if (documentsOnly) return;
for (RankProfile profile : rankProfileRegistry.rankProfilesOf(schema)) {
if (profile.getFirstPhaseRanking() != null) {
- process(profile.getFirstPhaseRanking().getRoot());
+ process(profile.getFirstPhaseRanking().getRoot(), profile);
}
if (profile.getSecondPhaseRanking() != null) {
- process(profile.getSecondPhaseRanking().getRoot());
+ process(profile.getSecondPhaseRanking().getRoot(), profile);
}
for (Map.Entry<String, RankProfile.RankingExpressionFunction> function : profile.getFunctions().entrySet()) {
- process(function.getValue().function().getBody().getRoot());
+ process(function.getValue().function().getBody().getRoot(), profile);
}
for (ReferenceNode feature : profile.getSummaryFeatures()) {
- process(feature);
+ process(feature, profile);
}
}
}
- private void process(ExpressionNode node) {
+ private void process(ExpressionNode node, RankProfile profile) {
if (node instanceof ReferenceNode) {
- process((ReferenceNode)node);
+ process((ReferenceNode)node, profile);
} else if (node instanceof CompositeNode) {
for (ExpressionNode child : ((CompositeNode) node).children()) {
- process(child);
+ process(child, profile);
}
}
}
- private void process(ReferenceNode feature) {
+ private void process(ReferenceNode feature, RankProfile profile) {
if (feature.getName().equals("onnxModel") || feature.getName().equals("onnx")) {
if (feature.getArguments().size() > 0) {
if (feature.getArguments().expressions().get(0) instanceof ConstantNode) {
@@ -85,11 +85,9 @@ public class OnnxModelConfigGenerator extends Processor {
}
}
- OnnxModel onnxModel = schema.onnxModels().get(modelConfigName);
- if (onnxModel == null) {
- onnxModel = new OnnxModel(modelConfigName, path);
- schema.onnxModels().add(onnxModel);
- }
+ OnnxModel onnxModel = profile.onnxModels().get(modelConfigName);
+ if (onnxModel == null)
+ profile.add(new OnnxModel(modelConfigName, path));
}
}
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java
index 4153cca4b5b..83e0c367292 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java
@@ -4,6 +4,7 @@ package com.yahoo.searchdefinition.processing;
import com.yahoo.config.application.api.DeployLogger;
import com.yahoo.searchdefinition.OnnxModel;
+import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.RankProfileRegistry;
import com.yahoo.searchdefinition.Schema;
import com.yahoo.vespa.model.container.search.QueryProfiles;
@@ -28,9 +29,11 @@ public class OnnxModelTypeResolver extends Processor {
@Override
public void process(boolean validate, boolean documentsOnly) {
if (documentsOnly) return;
- for (OnnxModel onnxModel : schema.onnxModels().asMap().values()) {
- OnnxModelInfo onnxModelInfo = OnnxModelInfo.load(onnxModel.getFileName(), schema.applicationPackage());
- onnxModel.setModelInfo(onnxModelInfo);
+ for (OnnxModel onnxModel : schema.declaredOnnxModels().values())
+ onnxModel.setModelInfo(OnnxModelInfo.load(onnxModel.getFileName(), schema.applicationPackage()));
+ for (RankProfile profile : rankProfileRegistry.rankProfilesOf(schema)) {
+ for (OnnxModel onnxModel : profile.declaredOnnxModels().values())
+ onnxModel.setModelInfo(OnnxModelInfo.load(onnxModel.getFileName(), schema.applicationPackage()));
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
index 03fc7592b40..149326130be 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
@@ -32,7 +32,7 @@ import com.yahoo.container.QrConfig;
import com.yahoo.path.Path;
import com.yahoo.searchdefinition.LargeRankExpressions;
import com.yahoo.searchdefinition.OnnxModel;
-import com.yahoo.searchdefinition.OnnxModels;
+import com.yahoo.searchdefinition.derived.FileDistributedOnnxModels;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.RankProfileRegistry;
import com.yahoo.searchdefinition.derived.AttributeFields;
@@ -175,11 +175,9 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
VespaModelBuilder builder = new VespaDomBuilder();
root = builder.getRoot(VespaModel.ROOT_CONFIGID, deployState, this);
- createGlobalRankProfiles(deployState, deployState.getFileRegistry());
+ createGlobalRankProfiles(deployState);
rankProfileList = new RankProfileList(null, // null search -> global
- Map.of(),
new LargeRankExpressions(deployState.getFileRegistry()),
- new OnnxModels(deployState.getFileRegistry(), Optional.empty()),
AttributeFields.empty,
deployState);
@@ -274,7 +272,7 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
* Creates a rank profile not attached to any search definition, for each imported model in the application package,
* and adds it to the given rank profile registry.
*/
- private void createGlobalRankProfiles(DeployState deployState, FileRegistry fileRegistry) {
+ private void createGlobalRankProfiles(DeployState deployState) {
var importedModels = deployState.getImportedModels().all();
DeployLogger deployLogger = deployState.getDeployLogger();
RankProfileRegistry rankProfileRegistry = deployState.rankProfileRegistry();
@@ -283,8 +281,8 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
if ( ! importedModels.isEmpty()) { // models/ directory is available
for (ImportedMlModel model : importedModels) {
// Due to automatic naming not guaranteeing unique names, there must be a 1-1 between OnnxModels and global RankProfiles.
- OnnxModels onnxModels = onnxModelInfoFromSource(model, fileRegistry);
- RankProfile profile = new RankProfile(model.name(), applicationPackage, deployLogger, rankProfileRegistry, onnxModels);
+ RankProfile profile = new RankProfile(model.name(), applicationPackage, deployLogger, rankProfileRegistry);
+ addOnnxModelInfoFromSource(model, profile);
rankProfileRegistry.add(profile);
futureModels.add(deployState.getExecutor().submit(() -> {
ConvertedModel convertedModel = ConvertedModel.fromSource(applicationPackage, new ModelName(model.name()),
@@ -300,8 +298,8 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
String modelName = generatedModelDir.getPath().last();
if (modelName.contains(".")) continue; // Name space: Not a global profile
// Due to automatic naming not guaranteeing unique names, there must be a 1-1 between OnnxModels and global RankProfiles.
- OnnxModels onnxModels = onnxModelInfoFromStore(modelName, fileRegistry);
- RankProfile profile = new RankProfile(modelName, applicationPackage, deployLogger, rankProfileRegistry, onnxModels);
+ RankProfile profile = new RankProfile(modelName, applicationPackage, deployLogger, rankProfileRegistry);
+ addOnnxModelInfoFromStore(modelName, profile);
rankProfileRegistry.add(profile);
futureModels.add(deployState.getExecutor().submit(() -> {
ConvertedModel convertedModel = ConvertedModel.fromStore(applicationPackage, new ModelName(modelName), modelName, profile);
@@ -320,27 +318,23 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
new Processing().processRankProfiles(deployLogger, rankProfileRegistry, queryProfiles, true, false);
}
- private OnnxModels onnxModelInfoFromSource(ImportedMlModel model, FileRegistry fileRegistry) {
- OnnxModels onnxModels = new OnnxModels(fileRegistry, Optional.empty());
+ private void addOnnxModelInfoFromSource(ImportedMlModel model, RankProfile profile) {
if (model.modelType() == ImportedMlModel.ModelType.ONNX) {
String path = model.source();
String applicationPath = this.applicationPackage.getFileReference(Path.fromString("")).toString();
if (path.startsWith(applicationPath)) {
path = path.substring(applicationPath.length() + 1);
}
- loadOnnxModelInfo(onnxModels, model.name(), path);
+ addOnnxModelInfo(model.name(), path, profile);
}
- return onnxModels;
}
- private OnnxModels onnxModelInfoFromStore(String modelName, FileRegistry fileRegistry) {
+ private void addOnnxModelInfoFromStore(String modelName, RankProfile profile) {
String path = ApplicationPackage.MODELS_DIR.append(modelName + ".onnx").toString();
- OnnxModels onnxModels = new OnnxModels(fileRegistry, Optional.empty());
- loadOnnxModelInfo(onnxModels, modelName, path);
- return onnxModels;
+ addOnnxModelInfo(modelName, path, profile);
}
- private void loadOnnxModelInfo(OnnxModels onnxModels, String name, String path) {
+ private void addOnnxModelInfo(String name, String path, RankProfile profile) {
boolean modelExists = OnnxModelInfo.modelExists(path, this.applicationPackage);
if ( ! modelExists) {
path = ApplicationPackage.MODELS_DIR.append(path).toString();
@@ -351,7 +345,7 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
if (onnxModelInfo.getModelPath() != null) {
OnnxModel onnxModel = new OnnxModel(name, onnxModelInfo.getModelPath());
onnxModel.setModelInfo(onnxModelInfo);
- onnxModels.add(onnxModel);
+ profile.add(onnxModel);
}
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantValidator.java
index b033d7a1e3b..fb70f2b769c 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantValidator.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantValidator.java
@@ -25,7 +25,7 @@ public class ConstantValidator extends Validator {
public void validate(VespaModel model, DeployState deployState) {
var exceptionMessageCollector = new ExceptionMessageCollector("Invalid constant tensor file(s):");
for (Schema schema : deployState.getSchemas()) {
- for (var constant : schema.constants().values())
+ for (var constant : schema.declaredConstants().values())
validate(constant, deployState.getApplicationPackage(), exceptionMessageCollector);
for (var profile : deployState.rankProfileRegistry().rankProfilesOf(schema)) {
for (var constant : profile.declaredConstants().values())
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java
index b9b7f122d63..3492ccf0b21 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java
@@ -152,7 +152,7 @@ public class RankSetupValidator extends Validator {
List<String> config = new ArrayList<>();
// Assist verify-ranksetup in finding the actual ONNX model files
- writeExtraVerifyRankSetupConfig(config, db.getDerivedConfiguration().getSchema().onnxModels().asMap().values());
+ writeExtraVerifyRankSetupConfig(config, db.getDerivedConfiguration().getRankProfileList().getOnnxModels().asMap().values());
writeExtraVerifyRankSetupConfig(config, db.getDerivedConfiguration().getSchema().rankExpressionFiles().expressions());
config.sort(String::compareTo);
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
index bb84f809fc4..17640417b3f 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
@@ -598,7 +598,7 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> {
Element onnxElement = XML.getChild(modelEvaluationElement, "onnx");
Element modelsElement = XML.getChild(onnxElement, "models");
for (Element modelElement : XML.getChildren(modelsElement, "model") ) {
- OnnxModel onnxModel = profiles.getOnnxModels().get(modelElement.getAttribute("name"));
+ OnnxModel onnxModel = profiles.getOnnxModels().asMap().get(modelElement.getAttribute("name"));
if (onnxModel == null)
continue; // Skip if model is not found
onnxModel.setStatelessExecutionMode(getStringValue(modelElement, "execution-mode", null));
diff --git a/config-model/src/main/javacc/IntermediateParser.jj b/config-model/src/main/javacc/IntermediateParser.jj
index 873196d8bda..01f111df284 100644
--- a/config-model/src/main/javacc/IntermediateParser.jj
+++ b/config-model/src/main/javacc/IntermediateParser.jj
@@ -427,7 +427,7 @@ void rootSchemaItem(ParsedSchema schema) : { }
| structOutside(schema)
| annotationOutside(schema)
| fieldSet(schema)
- | onnxModel(schema)
+ | onnxModelInSchema(schema) // Deprecated: TODO: Emit warning when on Vespa 8
)
}
@@ -1703,31 +1703,38 @@ void hnswIndexBody(HnswIndexParams.Builder params) :
| <MULTITHREADEDINDEXING> <COLON> bool = bool() { params.setMultiThreadedIndexing(bool); } )
}
-/**
- * Consumes a onnx-model block of a schema element.
- *
- * @param schema the schema object to add content to.
- */
-void onnxModel(ParsedSchema schema) :
+void onnxModelInSchema(ParsedSchema schema) :
+{
+ OnnxModel onnxModel;
+}
+{
+ onnxModel = onnxModel() { schema.add(onnxModel); }
+}
+
+void onnxModelInProfile(ParsedRankProfile profile) :
+{
+ OnnxModel onnxModel;
+}
+{
+ onnxModel = onnxModel() { profile.add(onnxModel); }
+}
+
+/** Consumes an onnx-model block. */
+OnnxModel onnxModel() :
{
String name;
OnnxModel onnxModel;
}
{
- ( <ONNXMODEL> name = identifier()
- {
- onnxModel = new OnnxModel(name);
- }
+ ( <ONNXMODEL> name = identifier() { onnxModel = new OnnxModel(name); }
lbrace() (onnxModelItem(onnxModel) (<NL>)*)+ <RBRACE> )
- {
- schema.addOnnxModel(onnxModel);
- }
+ { return onnxModel; }
}
/**
- * This rule consumes an onnx-model block.
+ * Consumes an onnx-model block.
*
- * @param onnxModel The onnxModel to modify.
+ * @param onnxModel the onnxModel to modify
*/
void onnxModelItem(OnnxModel onnxModel) :
{
@@ -1849,6 +1856,7 @@ void rankProfileItem(ParsedSchema schema, ParsedRankProfile profile) : { }
| constants(schema, profile)
| matchFeatures(profile)
| summaryFeatures(profile)
+ | onnxModelInProfile(profile)
| strict(profile) )
}
diff --git a/config-model/src/test/integration/onnx-model/schemas/test.sd b/config-model/src/test/integration/onnx-model/schemas/test.sd
index a15714767ba..82872758dd9 100644
--- a/config-model/src/test/integration/onnx-model/schemas/test.sd
+++ b/config-model/src/test/integration/onnx-model/schemas/test.sd
@@ -21,14 +21,6 @@ search test {
output "path/to/output:0": out
}
- onnx-model another_model {
- file: files/model.onnx
- input first_input: attribute(document_field)
- input "second/input:0": constant(my_constant)
- input "third_input": another_function
- output "path/to/output:2": out
- }
-
onnx-model dynamic_model {
file: files/dynamic_model.onnx
input input: my_function
@@ -72,6 +64,13 @@ search test {
first-phase {
expression: 1
}
+ onnx-model another_model {
+ file: files/model.onnx
+ input first_input: attribute(document_field)
+ input "second/input:0": constant(my_constant)
+ input "third_input": another_function
+ output "path/to/output:2": out
+ }
summary-features {
onnx(another_model).out
onnx("files/summary_model.onnx", "path/to/output:2")
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/SchemaTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/SchemaTestCase.java
index 65f4dab3650..cb57746d82f 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/SchemaTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/SchemaTestCase.java
@@ -196,8 +196,8 @@ public class SchemaTestCase {
assertTrue(child1profile.constants().containsKey(FeatureNames.asConstantFeature("parent_constant")));
assertNotNull(child1.onnxModels().get("parent_model"));
assertNotNull(child1.onnxModels().get("child1_model"));
- assertTrue(child1.onnxModels().asMap().containsKey("parent_model"));
- assertTrue(child1.onnxModels().asMap().containsKey("child1_model"));
+ assertTrue(child1.onnxModels().containsKey("parent_model"));
+ assertTrue(child1.onnxModels().containsKey("child1_model"));
assertNotNull(child1.getSummary("parent_summary"));
assertNotNull(child1.getSummary("child1_summary"));
assertEquals("parent_summary", child1.getSummary("child1_summary").inherited().get().getName());
@@ -231,8 +231,8 @@ public class SchemaTestCase {
assertTrue(child2.constants().containsKey(FeatureNames.asConstantFeature("child2_constant")));
assertNotNull(child2.onnxModels().get("parent_model"));
assertNotNull(child2.onnxModels().get("child2_model"));
- assertTrue(child2.onnxModels().asMap().containsKey("parent_model"));
- assertTrue(child2.onnxModels().asMap().containsKey("child2_model"));
+ assertTrue(child2.onnxModels().containsKey("parent_model"));
+ assertTrue(child2.onnxModels().containsKey("child2_model"));
assertNotNull(child2.getSummary("parent_summary"));
assertNotNull(child2.getSummary("child2_summary"));
assertEquals("parent_summary", child2.getSummary("child2_summary").inherited().get().getName());
@@ -430,7 +430,7 @@ public class SchemaTestCase {
assertNotNull(schema.constants().get(FeatureNames.asConstantFeature("parent_constant")));
assertTrue(schema.constants().containsKey(FeatureNames.asConstantFeature("parent_constant")));
assertNotNull(schema.onnxModels().get("parent_model"));
- assertTrue(schema.onnxModels().asMap().containsKey("parent_model"));
+ assertTrue(schema.onnxModels().containsKey("parent_model"));
assertNotNull(schema.getSummary("parent_summary"));
assertTrue(schema.getSummaries().containsKey("parent_summary"));
assertNotNull(schema.getSummaryField("pf1"));
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/GeminiTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/GeminiTestCase.java
index 07e6fbf7b1b..207792ffe06 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/GeminiTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/GeminiTestCase.java
@@ -19,7 +19,7 @@ public class GeminiTestCase extends AbstractExportingTestCase {
@Test
public void testRanking2() throws IOException, ParseException {
DerivedConfiguration c = assertCorrectDeriving("gemini2");
- RawRankProfile p = c.getRankProfileList().getRankProfile("test");
+ RawRankProfile p = c.getRankProfileList().getRankProfiles().get("test");
Map<String, String> ranking = removePartKeySuffixes(asMap(p.configProperties()));
assertEquals("attribute(right)", resolve(lookup("toplevel", ranking), ranking));
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
index 1c23950d972..6820a8d9678 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
@@ -93,6 +93,18 @@ public class RankingExpressionWithOnnxModelTestCase {
assertEquals("path_to_output_2", model.output(2).as());
model = config.model(1);
+ assertEquals("dynamic_model", model.name());
+ assertEquals(1, model.input().size());
+ assertEquals(1, model.output().size());
+ assertEquals("rankingExpression(my_function)", model.input(0).source());
+
+ model = config.model(2);
+ assertEquals("unbound_model", model.name());
+ assertEquals(1, model.input().size());
+ assertEquals(1, model.output().size());
+ assertEquals("rankingExpression(my_function)", model.input(0).source());
+
+ model = config.model(3);
assertEquals("files_model_onnx", model.name());
assertEquals(3, model.input().size());
assertEquals(3, model.output().size());
@@ -104,27 +116,15 @@ public class RankingExpressionWithOnnxModelTestCase {
assertEquals("path_to_output_2", model.output(2).as());
assertEquals("files_model_onnx", model.name());
- model = config.model(2);
+ model = config.model(4);
assertEquals("another_model", model.name());
assertEquals("third_input", model.input(2).name());
assertEquals("rankingExpression(another_function)", model.input(2).source());
- model = config.model(3);
+ model = config.model(5);
assertEquals("files_summary_model_onnx", model.name());
assertEquals(3, model.input().size());
assertEquals(3, model.output().size());
-
- model = config.model(4);
- assertEquals("unbound_model", model.name());
- assertEquals(1, model.input().size());
- assertEquals(1, model.output().size());
- assertEquals("rankingExpression(my_function)", model.input(0).source());
-
- model = config.model(5);
- assertEquals("dynamic_model", model.name());
- assertEquals(1, model.input().size());
- assertEquals(1, model.output().size());
- assertEquals("rankingExpression(my_function)", model.input(0).source());
}
private void assertTransformedFeature(VespaModel model) {
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/VespaMlModelTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/VespaMlModelTestCase.java
index e0427d93ee4..4446f01aa95 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/VespaMlModelTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/VespaMlModelTestCase.java
@@ -68,7 +68,7 @@ public class VespaMlModelTestCase {
private String rankConfigOf(String rankProfileName, VespaModel model) {
StringBuilder b = new StringBuilder();
- RawRankProfile profile = model.rankProfileList().getRankProfile(rankProfileName);
+ RawRankProfile profile = model.rankProfileList().getRankProfiles().get(rankProfileName);
for (var property : profile.configProperties())
b.append(property.getFirst()).append(" : ").append(property.getSecond()).append("\n");
return b.toString();