diff options
author | Lester Solbakken <lesters@oath.com> | 2020-09-02 15:33:12 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2020-09-02 15:33:12 +0200 |
commit | c80e560d472bb37643b0eadac9e4915b16d11f3d (patch) | |
tree | 418f7faa37ac08f3242a872c8cc618279a9e8fec /config-model/src/main/java/com | |
parent | f7551282dae23a3d079c406bd15b23025bbe0f68 (diff) |
Add config generation for models evaluated via ONNXRT
Diffstat (limited to 'config-model/src/main/java/com')
10 files changed, 267 insertions, 2 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/ImmutableSearch.java b/config-model/src/main/java/com/yahoo/searchdefinition/ImmutableSearch.java index 0b9447d05f5..6ac73ad45a9 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/ImmutableSearch.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/ImmutableSearch.java @@ -28,6 +28,7 @@ public interface ImmutableSearch { Reader getRankingExpression(String fileName); ApplicationPackage applicationPackage(); RankingConstants rankingConstants(); + OnnxModels onnxModels(); Stream<ImmutableSDField> allImportedFields(); ImmutableSDField getField(String name); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java new file mode 100644 index 00000000000..b7b18887dd8 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java @@ -0,0 +1,79 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchdefinition; + +import com.yahoo.config.FileReference; +import com.yahoo.vespa.model.AbstractService; +import com.yahoo.vespa.model.utils.FileSender; + +import java.util.Collection; +import java.util.Objects; + +/** + * A global ONNX model distributed using file distribution, similar to ranking constants. + * + * @author lesters + */ +public class OnnxModel { + + public enum PathType {FILE, URI}; + + private final String name; + private String path = null; + private String fileReference = ""; + + public PathType getPathType() { + return pathType; + } + + private PathType pathType = PathType.FILE; + + public OnnxModel(String name) { + this.name = name; + } + + public OnnxModel(String name, String fileName) { + this(name); + this.path = fileName; + validate(); + } + + public void setFileName(String fileName) { + Objects.requireNonNull(fileName, "Filename cannot be null"); + this.path = fileName; + this.pathType = PathType.FILE; + } + + public void setUri(String uri) { + Objects.requireNonNull(uri, "uri cannot be null"); + this.path = uri; + this.pathType = PathType.URI; + } + + /** Initiate sending of this constant to some services over file distribution */ + public void sendTo(Collection<? extends AbstractService> services) { + FileReference reference = (pathType == OnnxModel.PathType.FILE) + ? FileSender.sendFileToServices(path, services) + : FileSender.sendUriToServices(path, services); + this.fileReference = reference.value(); + } + + public String getName() { return name; } + public String getFileName() { return path; } + public String getUri() { return path; } + public String getFileReference() { return fileReference; } + + public void validate() { + if (path == null || path.isEmpty()) + throw new IllegalArgumentException("ONNX models must have a file or uri."); + } + + public String toString() { + StringBuilder b = new StringBuilder(); + b.append("onnx-model '").append(name) + .append(pathType == PathType.FILE ? "' from file '" : " from uri ").append(path) + .append("' with ref '").append(fileReference) + .append("'"); + return b.toString(); + } + +} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java new file mode 100644 index 00000000000..87663ac79a3 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java @@ -0,0 +1,39 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchdefinition; + +import com.yahoo.vespa.model.AbstractService; + +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** + * ONNX models tied to a search definition. + * + * @author lesters + */ +public class OnnxModels { + + private final Map<String, OnnxModel> models = new HashMap<>(); + + public void add(OnnxModel model) { + model.validate(); + String name = model.getName(); + models.put(name, model); + } + + public OnnxModel get(String name) { + return models.get(name); + } + + public Map<String, OnnxModel> asMap() { + return Collections.unmodifiableMap(models); + } + + /** Initiate sending of these models to some services over file distribution */ + public void sendTo(Collection<? extends AbstractService> services) { + models.values().forEach(model -> model.sendTo(services)); + } + +} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/Search.java b/config-model/src/main/java/com/yahoo/searchdefinition/Search.java index 3a60a75f75f..64c5590b689 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/Search.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/Search.java @@ -79,6 +79,9 @@ public class Search implements ImmutableSearch { /** Ranking constants of this */ private final RankingConstants rankingConstants = new RankingConstants(); + /** Onnx models of this */ + private final OnnxModels onnxModels = new OnnxModels(); + private Optional<TemporaryImportedFields> temporaryImportedFields = Optional.of(new TemporaryImportedFields()); private Optional<ImportedFields> importedFields = Optional.empty(); @@ -159,6 +162,9 @@ public class Search implements ImmutableSearch { @Override public RankingConstants rankingConstants() { return rankingConstants; } + @Override + public OnnxModels onnxModels() { return onnxModels; } + public Optional<TemporaryImportedFields> temporaryImportedFields() { return temporaryImportedFields; } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java index f6d4889b55e..00076c84532 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java @@ -4,12 +4,15 @@ package com.yahoo.searchdefinition.derived; import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels; import com.yahoo.config.model.api.ModelContext; import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.searchdefinition.OnnxModel; +import com.yahoo.searchdefinition.OnnxModels; import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.RankingConstant; import com.yahoo.searchdefinition.RankingConstants; import com.yahoo.vespa.config.search.RankProfilesConfig; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.Search; +import com.yahoo.vespa.config.search.core.OnnxModelsConfig; import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import com.yahoo.vespa.model.AbstractService; @@ -22,17 +25,21 @@ import java.util.logging.Logger; * * @author bratseth */ -public class RankProfileList extends Derived implements RankProfilesConfig.Producer, RankingConstantsConfig.Producer { +public class RankProfileList extends Derived implements RankProfilesConfig.Producer, + RankingConstantsConfig.Producer, + OnnxModelsConfig.Producer { private static final Logger log = Logger.getLogger(RankProfileList.class.getName()); private final Map<String, RawRankProfile> rankProfiles = new java.util.LinkedHashMap<>(); private final RankingConstants rankingConstants; + private final OnnxModels onnxModels; public static RankProfileList empty = new RankProfileList(); private RankProfileList() { this.rankingConstants = new RankingConstants(); + this.onnxModels = new OnnxModels(); } /** @@ -51,6 +58,7 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ setName(search == null ? "default" : search.getName()); this.rankingConstants = rankingConstants; deriveRankProfiles(rankProfileRegistry, queryProfiles, importedModels, search, attributeFields, deployProperties); + this.onnxModels = search == null ? new OnnxModels() : search.onnxModels(); // as ONNX models come from parsing rank expressions } private void deriveRankProfiles(RankProfileRegistry rankProfileRegistry, @@ -109,4 +117,15 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ } } + @Override + public void getConfig(OnnxModelsConfig.Builder builder) { + for (OnnxModel model : onnxModels.asMap().values()) { + if ("".equals(model.getFileReference())) + log.warning("Illegal file reference " + model); // Let tests pass ... we should find a better way + else + builder.model(new OnnxModelsConfig.Model.Builder() + .name(model.getName()) + .fileref(model.getFileReference())); + } + } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java index a6707ec7ac0..a723be8b478 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java @@ -26,6 +26,7 @@ public class ExpressionTransforms { transforms = ImmutableList.of(new TensorFlowFeatureConverter(), new OnnxFeatureConverter(), + new OnnxModelTransformer(), new XgboostFeatureConverter(), new LightGBMFeatureConverter(), new ConstantDereferencer(), diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java new file mode 100644 index 00000000000..d8ffbd7d030 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java @@ -0,0 +1,78 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchdefinition.expressiontransforms; + +import com.yahoo.searchdefinition.OnnxModel; +import com.yahoo.searchlib.rankingexpression.rule.Arguments; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; + +import java.util.List; + +/** + * Transforms instances of the onnxModel(model-path, output) ranking feature + * by adding the model file to file distribution and rewriting this feature + * to point to the generated configuration. + * + * @author lesters + */ +public class OnnxModelTransformer extends ExpressionTransformer<RankProfileTransformContext> { + + @Override + public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) { + if (node instanceof ReferenceNode) + return transformFeature((ReferenceNode) node, context); + else if (node instanceof CompositeNode) + return super.transformChildren((CompositeNode) node, context); + else + return node; + } + + private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) { + if (!feature.getName().equals("onnxModel")) return feature; + + Arguments arguments = feature.getArguments(); + if (arguments.isEmpty()) + throw new IllegalArgumentException("An onnxModel feature must take an argument pointing to the ONNX file."); + if (arguments.expressions().size() > 2) + throw new IllegalArgumentException("An onnxModel feature can have at most 2 arguments."); + + String path = asString(arguments.expressions().get(0)); + String name = toModelName(path); + String output = arguments.expressions().size() > 1 ? asString(arguments.expressions().get(1)) : null; + + // Validation that the file actually exists is handled when the file is added to file distribution. + // Validation of inputs, outputs and corresponding types are currently handled by RankSetupValidator. + + // Add model to config + context.rankProfile().getSearch().onnxModels().add(new OnnxModel(name, path)); + + // Replace feature with name of config + ExpressionNode argument = new ReferenceNode(name); + return new ReferenceNode("onnxModel", List.of(argument), output); + } + + private static String asString(ExpressionNode node) { + if ( ! (node instanceof ConstantNode)) + throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node); + return stripQuotes(((ConstantNode)node).sourceString()); + } + + private static String stripQuotes(String s) { + if ( ! isQuoteSign(s.codePointAt(0))) return s; + if ( ! isQuoteSign(s.codePointAt(s.length() - 1 ))) + throw new IllegalArgumentException("argument [" + s + "] is missing endquote"); + return s.substring(1, s.length()-1); + } + + private static boolean isQuoteSign(int c) { + return c == '\'' || c == '"'; + } + + public static String toModelName(String path) { + return path.replaceAll("[^\\w\\d\\$@_]", "_"); + } + +} diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java index 96dc312abd7..c6c7969e466 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java @@ -1,21 +1,25 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.application.validation; +import com.yahoo.cloud.config.ConfigserverConfig; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.io.IOUtils; import com.yahoo.log.InvalidLogFormatException; import java.util.logging.Level; import com.yahoo.log.LogMessage; +import com.yahoo.searchdefinition.OnnxModel; import com.yahoo.yolean.Exceptions; import com.yahoo.system.ProcessExecuter; import com.yahoo.text.StringUtilities; import com.yahoo.vespa.config.search.AttributesConfig; import com.yahoo.collections.Pair; import com.yahoo.config.ConfigInstance; +import com.yahoo.vespa.defaults.Defaults; import com.yahoo.vespa.config.search.ImportedFieldsConfig; import com.yahoo.vespa.config.search.IndexschemaConfig; import com.yahoo.vespa.config.search.RankProfilesConfig; import com.yahoo.vespa.config.search.core.RankingConstantsConfig; +import com.yahoo.vespa.config.search.core.OnnxModelsConfig; import com.yahoo.config.application.api.DeployLogger; import com.yahoo.config.model.producer.AbstractConfigProducer; import com.yahoo.vespa.model.VespaModel; @@ -27,9 +31,13 @@ import com.yahoo.vespa.model.search.SearchCluster; import java.io.File; import java.io.IOException; import java.nio.file.Files; +import java.nio.file.Paths; import java.time.Duration; import java.time.Instant; import java.util.logging.Logger; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; /** * Validate rank setup for all search clusters (rank-profiles, index-schema, attributes configs), validating done @@ -66,6 +74,7 @@ public class RankSetupValidator extends Validator { final String name = docDb.getDerivedConfiguration().getSearch().getName(); String searchDir = clusterDir + name + "/"; writeConfigs(searchDir, docDb); + writeExtraVerifyRanksetupConfig(searchDir, docDb); if ( ! validate("dir:" + searchDir, sc, name, deployState.getDeployLogger(), cfgDir)) { return; } @@ -126,12 +135,36 @@ public class RankSetupValidator extends Validator { RankingConstantsConfig rcc = new RankingConstantsConfig(rccb); writeConfig(dir, RankingConstantsConfig.getDefName() + ".cfg", rcc); + OnnxModelsConfig.Builder omcb = new OnnxModelsConfig.Builder(); + ((OnnxModelsConfig.Producer) producer).getConfig(omcb); + OnnxModelsConfig omc = new OnnxModelsConfig(omcb); + writeConfig(dir, OnnxModelsConfig.getDefName() + ".cfg", omc); + ImportedFieldsConfig.Builder ifcb = new ImportedFieldsConfig.Builder(); ((ImportedFieldsConfig.Producer) producer).getConfig(ifcb); ImportedFieldsConfig ifc = new ImportedFieldsConfig(ifcb); writeConfig(dir, ImportedFieldsConfig.getDefName() + ".cfg", ifc); } + private void writeExtraVerifyRanksetupConfig(String dir, DocumentDatabase db) throws IOException { + String configName = "verify-ranksetup.cfg"; + + // Assist verify-ranksetup in finding the actual ONNX model files + Map<String, OnnxModel> models = db.getDerivedConfiguration().getSearch().onnxModels().asMap(); + if (models.values().size() > 0) { + ConfigserverConfig cfg = new ConfigserverConfig(new ConfigserverConfig.Builder()); // assume defaults + String fileRefDir = Defaults.getDefaults().underVespaHome(cfg.fileReferencesDir()); + List<String> config = new ArrayList<>(models.values().size() * 2); + for (OnnxModel model : models.values()) { + String modelFilename = Paths.get(model.getFileName()).getFileName().toString(); + String modelPath = Paths.get(fileRefDir, model.getFileReference(), modelFilename).toString(); + config.add(String.format("file[%d].ref \"%s\"", config.size() / 2, model.getFileReference())); + config.add(String.format("file[%d].path \"%s\"", config.size() / 2, modelPath)); + } + IOUtils.writeFile(dir + configName, StringUtilities.implodeMultiline(config), false); + } + } + private static void writeConfig(String dir, String configName, ConfigInstance config) throws IOException { IOUtils.writeFile(dir + configName, StringUtilities.implodeMultiline(ConfigInstance.serialize(config)), false); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java index 58d608ec9f9..3f03df0107e 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java @@ -38,8 +38,10 @@ public abstract class AbstractSearchCluster extends AbstractConfigProducer } public void prepareToDistributeFiles(List<SearchNode> backends) { - for (SchemaSpec sds : localSDS) + for (SchemaSpec sds : localSDS) { sds.getSearchDefinition().getSearch().rankingConstants().sendTo(backends); + sds.getSearchDefinition().getSearch().onnxModels().sendTo(backends); + } } public void addDocumentNames(NamedSchema searchDefinition) { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/DocumentDatabase.java b/config-model/src/main/java/com/yahoo/vespa/model/search/DocumentDatabase.java index 74b9e7309c8..57ae4cf3a5b 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/search/DocumentDatabase.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/search/DocumentDatabase.java @@ -10,6 +10,7 @@ import com.yahoo.vespa.config.search.IndexschemaConfig; import com.yahoo.vespa.config.search.RankProfilesConfig; import com.yahoo.vespa.config.search.SummaryConfig; import com.yahoo.vespa.config.search.SummarymapConfig; +import com.yahoo.vespa.config.search.core.OnnxModelsConfig; import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import com.yahoo.vespa.config.search.summary.JuniperrcConfig; import com.yahoo.vespa.configdefinition.IlscriptsConfig; @@ -25,6 +26,7 @@ public class DocumentDatabase extends AbstractConfigProducer implements AttributesConfig.Producer, RankProfilesConfig.Producer, RankingConstantsConfig.Producer, + OnnxModelsConfig.Producer, IndexschemaConfig.Producer, JuniperrcConfig.Producer, SummarymapConfig.Producer, @@ -78,6 +80,11 @@ public class DocumentDatabase extends AbstractConfigProducer implements } @Override + public void getConfig(OnnxModelsConfig.Builder builder) { + derivedCfg.getRankProfileList().getConfig(builder); + } + + @Override public void getConfig(IndexschemaConfig.Builder builder) { derivedCfg.getIndexSchema().getConfig(builder); } |