summaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/searchdefinition/derived
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/derived')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java3
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/FileDistributedConstants.java11
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/FileDistributedOnnxModels.java63
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java149
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java4
5 files changed, 138 insertions, 92 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java
index 50029613b2e..c7b9b94a4b2 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java
@@ -82,8 +82,7 @@ public class DerivedConfiguration implements AttributesConfig.Producer {
summaries = new Summaries(schema, deployState.getDeployLogger(), deployState.getProperties().featureFlags());
summaryMap = new SummaryMap(schema);
juniperrc = new Juniperrc(schema);
- rankProfileList = new RankProfileList(schema, schema.constants(), schema.rankExpressionFiles(),
- schema.onnxModels(), attributeFields, deployState);
+ rankProfileList = new RankProfileList(schema, schema.rankExpressionFiles(), attributeFields, deployState);
indexingScript = new IndexingScript(schema);
indexInfo = new IndexInfo(schema);
schemaInfo = new SchemaInfo(schema, deployState.rankProfileRegistry(), summaries, summaryMap);
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/FileDistributedConstants.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/FileDistributedConstants.java
index 8de86beacdb..433bfb108d6 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/FileDistributedConstants.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/FileDistributedConstants.java
@@ -5,12 +5,12 @@ import com.yahoo.config.application.api.FileRegistry;
import com.yahoo.searchdefinition.DistributableResource;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.tensor.TensorType;
+import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
-import java.util.function.Function;
/**
* Constant values for ranking/model execution tied to a rank profile,
@@ -41,6 +41,15 @@ public class FileDistributedConstants {
/** Returns a read-only map of the constants in this indexed by name. */
public Map<String, DistributableConstant> asMap() { return constants; }
+ public void getConfig(RankingConstantsConfig.Builder builder) {
+ for (var constant : constants.values()) {
+ builder.constant(new RankingConstantsConfig.Constant.Builder()
+ .name(constant.getName())
+ .fileref(constant.getFileReference())
+ .type(constant.getType()));
+ }
+ }
+
public static class DistributableConstant extends DistributableResource {
private final TensorType tensorType;
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/FileDistributedOnnxModels.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/FileDistributedOnnxModels.java
new file mode 100644
index 00000000000..a1310d6c18e
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/FileDistributedOnnxModels.java
@@ -0,0 +1,63 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchdefinition.derived;
+
+import com.yahoo.config.application.api.FileRegistry;
+import com.yahoo.searchdefinition.OnnxModel;
+import com.yahoo.searchdefinition.Schema;
+import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.Optional;
+import java.util.logging.Logger;
+
+/**
+ * ONNX models distributed as files.
+ *
+ * @author bratseth
+ */
+public class FileDistributedOnnxModels {
+
+ private static final Logger log = Logger.getLogger(FileDistributedOnnxModels.class.getName());
+
+ private final Map<String, OnnxModel> models;
+
+ public FileDistributedOnnxModels(FileRegistry fileRegistry, Collection<OnnxModel> models) {
+ Map<String, OnnxModel> distributableModels = new LinkedHashMap<>();
+ for (var model : models) {
+ model.validate();
+ model.register(fileRegistry);
+ distributableModels.put(model.getName(), model);
+ }
+ this.models = Collections.unmodifiableMap(distributableModels);
+ }
+
+ public Map<String, OnnxModel> asMap() { return models; }
+
+ public void getConfig(OnnxModelsConfig.Builder builder) {
+ for (OnnxModel model : models.values()) {
+ if ("".equals(model.getFileReference()))
+ log.warning("Illegal file reference " + model); // Let tests pass ... we should find a better way
+ else {
+ OnnxModelsConfig.Model.Builder modelBuilder = new OnnxModelsConfig.Model.Builder();
+ modelBuilder.dry_run_on_setup(true);
+ modelBuilder.name(model.getName());
+ modelBuilder.fileref(model.getFileReference());
+ model.getInputMap().forEach((name, source) -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source)));
+ model.getOutputMap().forEach((name, as) -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(name).as(as)));
+ if (model.getStatelessExecutionMode().isPresent())
+ modelBuilder.stateless_execution_mode(model.getStatelessExecutionMode().get());
+ if (model.getStatelessInterOpThreads().isPresent())
+ modelBuilder.stateless_interop_threads(model.getStatelessInterOpThreads().get());
+ if (model.getStatelessIntraOpThreads().isPresent())
+ modelBuilder.stateless_intraop_threads(model.getStatelessIntraOpThreads().get());
+
+ builder.model(modelBuilder);
+ }
+ }
+ }
+
+}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
index d415f7d0d6e..081450275d1 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
@@ -2,13 +2,11 @@
package com.yahoo.searchdefinition.derived;
import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels;
-import com.yahoo.config.application.api.DeployLogger;
import com.yahoo.config.model.api.ModelContext;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.search.query.profile.QueryProfileRegistry;
-import com.yahoo.searchdefinition.OnnxModel;
-import com.yahoo.searchdefinition.OnnxModels;
import com.yahoo.searchdefinition.LargeRankExpressions;
+import com.yahoo.searchdefinition.OnnxModel;
import com.yahoo.searchdefinition.RankProfileRegistry;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.vespa.config.search.RankProfilesConfig;
@@ -24,13 +22,10 @@ import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
-import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
-import java.util.logging.Level;
-import java.util.logging.Logger;
/**
* The derived rank profiles of a schema
@@ -39,19 +34,17 @@ import java.util.logging.Logger;
*/
public class RankProfileList extends Derived implements RankProfilesConfig.Producer {
- private static final Logger log = Logger.getLogger(RankProfileList.class.getName());
-
private final Map<String, RawRankProfile> rankProfiles;
private final FileDistributedConstants constants;
private final LargeRankExpressions largeRankExpressions;
- private final OnnxModels onnxModels;
+ private final FileDistributedOnnxModels onnxModels;
public static final RankProfileList empty = new RankProfileList();
private RankProfileList() {
constants = new FileDistributedConstants(null, List.of());
largeRankExpressions = new LargeRankExpressions(null);
- onnxModels = new OnnxModels(null, Optional.empty());
+ onnxModels = new FileDistributedOnnxModels(null, List.of());
rankProfiles = Map.of();
}
@@ -62,46 +55,16 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
* @param attributeFields the attribute fields to create a ranking for
*/
public RankProfileList(Schema schema,
- Map<Reference, RankProfile.Constant> constantsFromSchema,
LargeRankExpressions largeRankExpressions,
- OnnxModels onnxModels,
AttributeFields attributeFields,
DeployState deployState) {
setName(schema == null ? "default" : schema.getName());
this.largeRankExpressions = largeRankExpressions;
- this.onnxModels = onnxModels; // as ONNX models come from parsing rank expressions
this.rankProfiles = deriveRankProfiles(schema, attributeFields, deployState);
- this.constants = deriveFileDistributedConstants(schema, constantsFromSchema, rankProfiles.values(), deployState);
- }
-
- private static FileDistributedConstants deriveFileDistributedConstants(Schema schema,
- Map<Reference, RankProfile.Constant> constantsFromSchema,
- Collection<RawRankProfile> rankProfiles,
- DeployState deployState) {
- Map<Reference, RankProfile.Constant> allFileConstants = new HashMap<>();
- addFileConstants(constantsFromSchema.values(), allFileConstants, schema != null ? schema.toString() : "[global]");
- for (var profile : rankProfiles)
- addFileConstants(profile.compiled().constants().values(), allFileConstants, profile.toString());
- return new FileDistributedConstants(deployState.getFileRegistry(), allFileConstants.values());
- }
-
- private static void addFileConstants(Collection<RankProfile.Constant> source,
- Map<Reference, RankProfile.Constant> destination,
- String sourceName) {
- for (var constant : source) {
- if (constant.valuePath().isEmpty()) continue;
- var existing = destination.get(constant.name());
- if ( existing != null && ! constant.equals(existing)) {
- throw new IllegalArgumentException("Duplicate constants: " + sourceName + " have " + constant +
- ", but we already have " + existing +
- ": Value reference constants must be unique across all rank profiles/models");
- }
- destination.put(constant.name(), constant);
- }
+ this.constants = deriveFileDistributedConstants(schema, rankProfiles.values(), deployState);
+ this.onnxModels = deriveFileDistributedOnnxModels(schema, rankProfiles.values(), deployState);
}
- public FileDistributedConstants constants() { return constants; }
-
private boolean areDependenciesReady(RankProfile rank, RankProfileRegistry registry, Set<String> processedProfiles) {
return rank.inheritedNames().isEmpty() ||
processedProfiles.containsAll(rank.inheritedNames()) ||
@@ -134,7 +97,6 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
rawRankProfiles.putAll(processRankProfiles(ready,
deployState.getQueryProfiles().getRegistry(),
deployState.getImportedModels(),
- schema,
attributeFields,
deployState.getProperties(),
deployState.getExecutor()));
@@ -143,20 +105,15 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
return rawRankProfiles;
}
- private Map<String, RawRankProfile> processRankProfiles(List<RankProfile> ready,
+ private Map<String, RawRankProfile> processRankProfiles(List<RankProfile> profiles,
QueryProfileRegistry queryProfiles,
ImportedMlModels importedModels,
- Schema schema,
AttributeFields attributeFields,
ModelContext.Properties deployProperties,
ExecutorService executor) {
Map<String, Future<RawRankProfile>> futureRawRankProfiles = new LinkedHashMap<>();
- for (RankProfile rank : ready) {
- if (schema == null) {
- onnxModels.add(rank.onnxModels());
- }
-
- futureRawRankProfiles.put(rank.name(), executor.submit(() -> new RawRankProfile(rank, largeRankExpressions, queryProfiles, importedModels,
+ for (RankProfile profile : profiles) {
+ futureRawRankProfiles.put(profile.name(), executor.submit(() -> new RawRankProfile(profile, largeRankExpressions, queryProfiles, importedModels,
attributeFields, deployProperties)));
}
try {
@@ -171,19 +128,63 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
}
}
- public OnnxModels getOnnxModels() {
- return onnxModels;
+ private static FileDistributedConstants deriveFileDistributedConstants(Schema schema,
+ Collection<RawRankProfile> rankProfiles,
+ DeployState deployState) {
+ Map<Reference, RankProfile.Constant> allFileConstants = new HashMap<>();
+ addFileConstants(schema != null ? schema.constants().values() : List.of(),
+ allFileConstants,
+ schema != null ? schema.toString() : "[global]");
+ for (var profile : rankProfiles)
+ addFileConstants(profile.compiled().constants().values(), allFileConstants, profile.toString());
+ return new FileDistributedConstants(deployState.getFileRegistry(), allFileConstants.values());
}
- public Map<String, RawRankProfile> getRankProfiles() {
- return rankProfiles;
+ private static void addFileConstants(Collection<RankProfile.Constant> source,
+ Map<Reference, RankProfile.Constant> destination,
+ String sourceName) {
+ for (var constant : source) {
+ if (constant.valuePath().isEmpty()) continue;
+ var existing = destination.get(constant.name());
+ if ( existing != null && ! constant.equals(existing)) {
+ throw new IllegalArgumentException("Duplicate constants: " + sourceName + " have " + constant +
+ ", but we already have " + existing +
+ ": Value reference constants must be unique across all rank profiles/models");
+ }
+ destination.put(constant.name(), constant);
+ }
}
- /** Returns the raw rank profile with the given name, or null if it is not present */
- public RawRankProfile getRankProfile(String name) {
- return rankProfiles.get(name);
+ private static FileDistributedOnnxModels deriveFileDistributedOnnxModels(Schema schema,
+ Collection<RawRankProfile> rankProfiles,
+ DeployState deployState) {
+ Map<String, OnnxModel> allModels = new LinkedHashMap<>();
+ addOnnxModels(schema != null ? schema.onnxModels().values() : List.of(),
+ allModels,
+ schema != null ? schema.toString() : "[global]");
+ for (var profile : rankProfiles)
+ addOnnxModels(profile.compiled().onnxModels().values(), allModels, profile.toString());
+ return new FileDistributedOnnxModels(deployState.getFileRegistry(), allModels.values());
}
+ private static void addOnnxModels(Collection<OnnxModel> source,
+ Map<String, OnnxModel> destination,
+ String sourceName) {
+ for (var model : source) {
+ var existing = destination.get(model.getName());
+ if ( existing != null && ! model.equals(existing)) {
+ throw new IllegalArgumentException("Duplicate onnx model: " + sourceName + " have " + model +
+ ", but we already have " + existing +
+ ": Onnx models must be unique across all rank profiles/models");
+ }
+ destination.put(model.getName(), model);
+ }
+ }
+
+ public Map<String, RawRankProfile> getRankProfiles() { return rankProfiles; }
+ public FileDistributedConstants constants() { return constants; }
+ public FileDistributedOnnxModels getOnnxModels() { return onnxModels; }
+
@Override
public String getDerivedName() { return "rank-profiles"; }
@@ -199,37 +200,11 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
}
public void getConfig(RankingConstantsConfig.Builder builder) {
- for (var constant : constants.asMap().values()) {
- if ("".equals(constant.getFileReference()))
- log.warning("Illegal file reference " + constant); // Let tests pass ... we should find a better way
- else
- builder.constant(new RankingConstantsConfig.Constant.Builder()
- .name(constant.getName())
- .fileref(constant.getFileReference())
- .type(constant.getType()));
- }
+ constants.getConfig(builder);
}
public void getConfig(OnnxModelsConfig.Builder builder) {
- for (OnnxModel model : onnxModels.asMap().values()) {
- if ("".equals(model.getFileReference()))
- log.warning("Illegal file reference " + model); // Let tests pass ... we should find a better way
- else {
- OnnxModelsConfig.Model.Builder modelBuilder = new OnnxModelsConfig.Model.Builder();
- modelBuilder.dry_run_on_setup(true);
- modelBuilder.name(model.getName());
- modelBuilder.fileref(model.getFileReference());
- model.getInputMap().forEach((name, source) -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source)));
- model.getOutputMap().forEach((name, as) -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(name).as(as)));
- if (model.getStatelessExecutionMode().isPresent())
- modelBuilder.stateless_execution_mode(model.getStatelessExecutionMode().get());
- if (model.getStatelessInterOpThreads().isPresent())
- modelBuilder.stateless_interop_threads(model.getStatelessInterOpThreads().get());
- if (model.getStatelessIntraOpThreads().isPresent())
- modelBuilder.stateless_intraop_threads(model.getStatelessIntraOpThreads().get());
-
- builder.model(modelBuilder);
- }
- }
+ onnxModels.getConfig(builder);
}
+
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
index 2189d9d97db..dcd5019cf58 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
@@ -483,7 +483,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
private void deriveOnnxModelFunctionsAndFeatures(RankProfile rankProfile) {
if (rankProfile.schema() == null) return;
- if (rankProfile.schema().onnxModels().asMap().isEmpty()) return;
+ if (rankProfile.onnxModels().isEmpty()) return;
replaceOnnxFunctionInputs(rankProfile);
replaceImplicitOnnxConfigFeatures(summaryFeatures, rankProfile);
replaceImplicitOnnxConfigFeatures(matchFeatures, rankProfile);
@@ -492,7 +492,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
private void replaceOnnxFunctionInputs(RankProfile rankProfile) {
Set<String> functionNames = rankProfile.getFunctions().keySet();
if (functionNames.isEmpty()) return;
- for (OnnxModel onnxModel: rankProfile.schema().onnxModels().asMap().values()) {
+ for (OnnxModel onnxModel: rankProfile.onnxModels().values()) {
for (Map.Entry<String, String> mapping : onnxModel.getInputMap().entrySet()) {
String source = mapping.getValue();
if (functionNames.contains(source)) {