aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2020-09-02 15:33:12 +0200
committerLester Solbakken <lesters@oath.com>2020-09-02 15:33:12 +0200
commitc80e560d472bb37643b0eadac9e4915b16d11f3d (patch)
tree418f7faa37ac08f3242a872c8cc618279a9e8fec /config-model/src/main/java/com
parentf7551282dae23a3d079c406bd15b23025bbe0f68 (diff)
Add config generation for models evaluated via ONNXRT
Diffstat (limited to 'config-model/src/main/java/com')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/ImmutableSearch.java1
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java79
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java39
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/Search.java6
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java21
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java1
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java78
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java33
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java4
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/search/DocumentDatabase.java7
10 files changed, 267 insertions, 2 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/ImmutableSearch.java b/config-model/src/main/java/com/yahoo/searchdefinition/ImmutableSearch.java
index 0b9447d05f5..6ac73ad45a9 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/ImmutableSearch.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/ImmutableSearch.java
@@ -28,6 +28,7 @@ public interface ImmutableSearch {
Reader getRankingExpression(String fileName);
ApplicationPackage applicationPackage();
RankingConstants rankingConstants();
+ OnnxModels onnxModels();
Stream<ImmutableSDField> allImportedFields();
ImmutableSDField getField(String name);
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
new file mode 100644
index 00000000000..b7b18887dd8
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
@@ -0,0 +1,79 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchdefinition;
+
+import com.yahoo.config.FileReference;
+import com.yahoo.vespa.model.AbstractService;
+import com.yahoo.vespa.model.utils.FileSender;
+
+import java.util.Collection;
+import java.util.Objects;
+
+/**
+ * A global ONNX model distributed using file distribution, similar to ranking constants.
+ *
+ * @author lesters
+ */
+public class OnnxModel {
+
+ public enum PathType {FILE, URI};
+
+ private final String name;
+ private String path = null;
+ private String fileReference = "";
+
+ public PathType getPathType() {
+ return pathType;
+ }
+
+ private PathType pathType = PathType.FILE;
+
+ public OnnxModel(String name) {
+ this.name = name;
+ }
+
+ public OnnxModel(String name, String fileName) {
+ this(name);
+ this.path = fileName;
+ validate();
+ }
+
+ public void setFileName(String fileName) {
+ Objects.requireNonNull(fileName, "Filename cannot be null");
+ this.path = fileName;
+ this.pathType = PathType.FILE;
+ }
+
+ public void setUri(String uri) {
+ Objects.requireNonNull(uri, "uri cannot be null");
+ this.path = uri;
+ this.pathType = PathType.URI;
+ }
+
+ /** Initiate sending of this constant to some services over file distribution */
+ public void sendTo(Collection<? extends AbstractService> services) {
+ FileReference reference = (pathType == OnnxModel.PathType.FILE)
+ ? FileSender.sendFileToServices(path, services)
+ : FileSender.sendUriToServices(path, services);
+ this.fileReference = reference.value();
+ }
+
+ public String getName() { return name; }
+ public String getFileName() { return path; }
+ public String getUri() { return path; }
+ public String getFileReference() { return fileReference; }
+
+ public void validate() {
+ if (path == null || path.isEmpty())
+ throw new IllegalArgumentException("ONNX models must have a file or uri.");
+ }
+
+ public String toString() {
+ StringBuilder b = new StringBuilder();
+ b.append("onnx-model '").append(name)
+ .append(pathType == PathType.FILE ? "' from file '" : " from uri ").append(path)
+ .append("' with ref '").append(fileReference)
+ .append("'");
+ return b.toString();
+ }
+
+}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java
new file mode 100644
index 00000000000..87663ac79a3
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java
@@ -0,0 +1,39 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchdefinition;
+
+import com.yahoo.vespa.model.AbstractService;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * ONNX models tied to a search definition.
+ *
+ * @author lesters
+ */
+public class OnnxModels {
+
+ private final Map<String, OnnxModel> models = new HashMap<>();
+
+ public void add(OnnxModel model) {
+ model.validate();
+ String name = model.getName();
+ models.put(name, model);
+ }
+
+ public OnnxModel get(String name) {
+ return models.get(name);
+ }
+
+ public Map<String, OnnxModel> asMap() {
+ return Collections.unmodifiableMap(models);
+ }
+
+ /** Initiate sending of these models to some services over file distribution */
+ public void sendTo(Collection<? extends AbstractService> services) {
+ models.values().forEach(model -> model.sendTo(services));
+ }
+
+}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/Search.java b/config-model/src/main/java/com/yahoo/searchdefinition/Search.java
index 3a60a75f75f..64c5590b689 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/Search.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/Search.java
@@ -79,6 +79,9 @@ public class Search implements ImmutableSearch {
/** Ranking constants of this */
private final RankingConstants rankingConstants = new RankingConstants();
+ /** Onnx models of this */
+ private final OnnxModels onnxModels = new OnnxModels();
+
private Optional<TemporaryImportedFields> temporaryImportedFields = Optional.of(new TemporaryImportedFields());
private Optional<ImportedFields> importedFields = Optional.empty();
@@ -159,6 +162,9 @@ public class Search implements ImmutableSearch {
@Override
public RankingConstants rankingConstants() { return rankingConstants; }
+ @Override
+ public OnnxModels onnxModels() { return onnxModels; }
+
public Optional<TemporaryImportedFields> temporaryImportedFields() {
return temporaryImportedFields;
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
index f6d4889b55e..00076c84532 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
@@ -4,12 +4,15 @@ package com.yahoo.searchdefinition.derived;
import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels;
import com.yahoo.config.model.api.ModelContext;
import com.yahoo.search.query.profile.QueryProfileRegistry;
+import com.yahoo.searchdefinition.OnnxModel;
+import com.yahoo.searchdefinition.OnnxModels;
import com.yahoo.searchdefinition.RankProfileRegistry;
import com.yahoo.searchdefinition.RankingConstant;
import com.yahoo.searchdefinition.RankingConstants;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.Search;
+import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import com.yahoo.vespa.model.AbstractService;
@@ -22,17 +25,21 @@ import java.util.logging.Logger;
*
* @author bratseth
*/
-public class RankProfileList extends Derived implements RankProfilesConfig.Producer, RankingConstantsConfig.Producer {
+public class RankProfileList extends Derived implements RankProfilesConfig.Producer,
+ RankingConstantsConfig.Producer,
+ OnnxModelsConfig.Producer {
private static final Logger log = Logger.getLogger(RankProfileList.class.getName());
private final Map<String, RawRankProfile> rankProfiles = new java.util.LinkedHashMap<>();
private final RankingConstants rankingConstants;
+ private final OnnxModels onnxModels;
public static RankProfileList empty = new RankProfileList();
private RankProfileList() {
this.rankingConstants = new RankingConstants();
+ this.onnxModels = new OnnxModels();
}
/**
@@ -51,6 +58,7 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
setName(search == null ? "default" : search.getName());
this.rankingConstants = rankingConstants;
deriveRankProfiles(rankProfileRegistry, queryProfiles, importedModels, search, attributeFields, deployProperties);
+ this.onnxModels = search == null ? new OnnxModels() : search.onnxModels(); // as ONNX models come from parsing rank expressions
}
private void deriveRankProfiles(RankProfileRegistry rankProfileRegistry,
@@ -109,4 +117,15 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
}
}
+ @Override
+ public void getConfig(OnnxModelsConfig.Builder builder) {
+ for (OnnxModel model : onnxModels.asMap().values()) {
+ if ("".equals(model.getFileReference()))
+ log.warning("Illegal file reference " + model); // Let tests pass ... we should find a better way
+ else
+ builder.model(new OnnxModelsConfig.Model.Builder()
+ .name(model.getName())
+ .fileref(model.getFileReference()));
+ }
+ }
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java
index a6707ec7ac0..a723be8b478 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java
@@ -26,6 +26,7 @@ public class ExpressionTransforms {
transforms =
ImmutableList.of(new TensorFlowFeatureConverter(),
new OnnxFeatureConverter(),
+ new OnnxModelTransformer(),
new XgboostFeatureConverter(),
new LightGBMFeatureConverter(),
new ConstantDereferencer(),
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
new file mode 100644
index 00000000000..d8ffbd7d030
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
@@ -0,0 +1,78 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchdefinition.expressiontransforms;
+
+import com.yahoo.searchdefinition.OnnxModel;
+import com.yahoo.searchlib.rankingexpression.rule.Arguments;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
+
+import java.util.List;
+
+/**
+ * Transforms instances of the onnxModel(model-path, output) ranking feature
+ * by adding the model file to file distribution and rewriting this feature
+ * to point to the generated configuration.
+ *
+ * @author lesters
+ */
+public class OnnxModelTransformer extends ExpressionTransformer<RankProfileTransformContext> {
+
+ @Override
+ public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
+ if (node instanceof ReferenceNode)
+ return transformFeature((ReferenceNode) node, context);
+ else if (node instanceof CompositeNode)
+ return super.transformChildren((CompositeNode) node, context);
+ else
+ return node;
+ }
+
+ private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) {
+ if (!feature.getName().equals("onnxModel")) return feature;
+
+ Arguments arguments = feature.getArguments();
+ if (arguments.isEmpty())
+ throw new IllegalArgumentException("An onnxModel feature must take an argument pointing to the ONNX file.");
+ if (arguments.expressions().size() > 2)
+ throw new IllegalArgumentException("An onnxModel feature can have at most 2 arguments.");
+
+ String path = asString(arguments.expressions().get(0));
+ String name = toModelName(path);
+ String output = arguments.expressions().size() > 1 ? asString(arguments.expressions().get(1)) : null;
+
+ // Validation that the file actually exists is handled when the file is added to file distribution.
+ // Validation of inputs, outputs and corresponding types are currently handled by RankSetupValidator.
+
+ // Add model to config
+ context.rankProfile().getSearch().onnxModels().add(new OnnxModel(name, path));
+
+ // Replace feature with name of config
+ ExpressionNode argument = new ReferenceNode(name);
+ return new ReferenceNode("onnxModel", List.of(argument), output);
+ }
+
+ private static String asString(ExpressionNode node) {
+ if ( ! (node instanceof ConstantNode))
+ throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node);
+ return stripQuotes(((ConstantNode)node).sourceString());
+ }
+
+ private static String stripQuotes(String s) {
+ if ( ! isQuoteSign(s.codePointAt(0))) return s;
+ if ( ! isQuoteSign(s.codePointAt(s.length() - 1 )))
+ throw new IllegalArgumentException("argument [" + s + "] is missing endquote");
+ return s.substring(1, s.length()-1);
+ }
+
+ private static boolean isQuoteSign(int c) {
+ return c == '\'' || c == '"';
+ }
+
+ public static String toModelName(String path) {
+ return path.replaceAll("[^\\w\\d\\$@_]", "_");
+ }
+
+}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java
index 96dc312abd7..c6c7969e466 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java
@@ -1,21 +1,25 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model.application.validation;
+import com.yahoo.cloud.config.ConfigserverConfig;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.io.IOUtils;
import com.yahoo.log.InvalidLogFormatException;
import java.util.logging.Level;
import com.yahoo.log.LogMessage;
+import com.yahoo.searchdefinition.OnnxModel;
import com.yahoo.yolean.Exceptions;
import com.yahoo.system.ProcessExecuter;
import com.yahoo.text.StringUtilities;
import com.yahoo.vespa.config.search.AttributesConfig;
import com.yahoo.collections.Pair;
import com.yahoo.config.ConfigInstance;
+import com.yahoo.vespa.defaults.Defaults;
import com.yahoo.vespa.config.search.ImportedFieldsConfig;
import com.yahoo.vespa.config.search.IndexschemaConfig;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
+import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.config.application.api.DeployLogger;
import com.yahoo.config.model.producer.AbstractConfigProducer;
import com.yahoo.vespa.model.VespaModel;
@@ -27,9 +31,13 @@ import com.yahoo.vespa.model.search.SearchCluster;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
+import java.nio.file.Paths;
import java.time.Duration;
import java.time.Instant;
import java.util.logging.Logger;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
/**
* Validate rank setup for all search clusters (rank-profiles, index-schema, attributes configs), validating done
@@ -66,6 +74,7 @@ public class RankSetupValidator extends Validator {
final String name = docDb.getDerivedConfiguration().getSearch().getName();
String searchDir = clusterDir + name + "/";
writeConfigs(searchDir, docDb);
+ writeExtraVerifyRanksetupConfig(searchDir, docDb);
if ( ! validate("dir:" + searchDir, sc, name, deployState.getDeployLogger(), cfgDir)) {
return;
}
@@ -126,12 +135,36 @@ public class RankSetupValidator extends Validator {
RankingConstantsConfig rcc = new RankingConstantsConfig(rccb);
writeConfig(dir, RankingConstantsConfig.getDefName() + ".cfg", rcc);
+ OnnxModelsConfig.Builder omcb = new OnnxModelsConfig.Builder();
+ ((OnnxModelsConfig.Producer) producer).getConfig(omcb);
+ OnnxModelsConfig omc = new OnnxModelsConfig(omcb);
+ writeConfig(dir, OnnxModelsConfig.getDefName() + ".cfg", omc);
+
ImportedFieldsConfig.Builder ifcb = new ImportedFieldsConfig.Builder();
((ImportedFieldsConfig.Producer) producer).getConfig(ifcb);
ImportedFieldsConfig ifc = new ImportedFieldsConfig(ifcb);
writeConfig(dir, ImportedFieldsConfig.getDefName() + ".cfg", ifc);
}
+ private void writeExtraVerifyRanksetupConfig(String dir, DocumentDatabase db) throws IOException {
+ String configName = "verify-ranksetup.cfg";
+
+ // Assist verify-ranksetup in finding the actual ONNX model files
+ Map<String, OnnxModel> models = db.getDerivedConfiguration().getSearch().onnxModels().asMap();
+ if (models.values().size() > 0) {
+ ConfigserverConfig cfg = new ConfigserverConfig(new ConfigserverConfig.Builder()); // assume defaults
+ String fileRefDir = Defaults.getDefaults().underVespaHome(cfg.fileReferencesDir());
+ List<String> config = new ArrayList<>(models.values().size() * 2);
+ for (OnnxModel model : models.values()) {
+ String modelFilename = Paths.get(model.getFileName()).getFileName().toString();
+ String modelPath = Paths.get(fileRefDir, model.getFileReference(), modelFilename).toString();
+ config.add(String.format("file[%d].ref \"%s\"", config.size() / 2, model.getFileReference()));
+ config.add(String.format("file[%d].path \"%s\"", config.size() / 2, modelPath));
+ }
+ IOUtils.writeFile(dir + configName, StringUtilities.implodeMultiline(config), false);
+ }
+ }
+
private static void writeConfig(String dir, String configName, ConfigInstance config) throws IOException {
IOUtils.writeFile(dir + configName, StringUtilities.implodeMultiline(ConfigInstance.serialize(config)), false);
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java
index 58d608ec9f9..3f03df0107e 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java
@@ -38,8 +38,10 @@ public abstract class AbstractSearchCluster extends AbstractConfigProducer
}
public void prepareToDistributeFiles(List<SearchNode> backends) {
- for (SchemaSpec sds : localSDS)
+ for (SchemaSpec sds : localSDS) {
sds.getSearchDefinition().getSearch().rankingConstants().sendTo(backends);
+ sds.getSearchDefinition().getSearch().onnxModels().sendTo(backends);
+ }
}
public void addDocumentNames(NamedSchema searchDefinition) {
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/DocumentDatabase.java b/config-model/src/main/java/com/yahoo/vespa/model/search/DocumentDatabase.java
index 74b9e7309c8..57ae4cf3a5b 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/search/DocumentDatabase.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/search/DocumentDatabase.java
@@ -10,6 +10,7 @@ import com.yahoo.vespa.config.search.IndexschemaConfig;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.vespa.config.search.SummaryConfig;
import com.yahoo.vespa.config.search.SummarymapConfig;
+import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import com.yahoo.vespa.config.search.summary.JuniperrcConfig;
import com.yahoo.vespa.configdefinition.IlscriptsConfig;
@@ -25,6 +26,7 @@ public class DocumentDatabase extends AbstractConfigProducer implements
AttributesConfig.Producer,
RankProfilesConfig.Producer,
RankingConstantsConfig.Producer,
+ OnnxModelsConfig.Producer,
IndexschemaConfig.Producer,
JuniperrcConfig.Producer,
SummarymapConfig.Producer,
@@ -78,6 +80,11 @@ public class DocumentDatabase extends AbstractConfigProducer implements
}
@Override
+ public void getConfig(OnnxModelsConfig.Builder builder) {
+ derivedCfg.getRankProfileList().getConfig(builder);
+ }
+
+ @Override
public void getConfig(IndexschemaConfig.Builder builder) {
derivedCfg.getIndexSchema().getConfig(builder);
}