diff options
Diffstat (limited to 'config-model/src/main')
6 files changed, 200 insertions, 23 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java index b7b18887dd8..c2fb2107604 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java @@ -5,7 +5,10 @@ import com.yahoo.config.FileReference; import com.yahoo.vespa.model.AbstractService; import com.yahoo.vespa.model.utils.FileSender; +import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; +import java.util.List; import java.util.Objects; /** @@ -20,6 +23,8 @@ public class OnnxModel { private final String name; private String path = null; private String fileReference = ""; + private List<OnnxNameMapping> inputMap = new ArrayList<>(); + private List<OnnxNameMapping> outputMap = new ArrayList<>(); public PathType getPathType() { return pathType; @@ -49,6 +54,18 @@ public class OnnxModel { this.pathType = PathType.URI; } + public void addInputNameMapping(String onnxName, String vespaName) { + Objects.requireNonNull(onnxName, "Onnx name cannot be null"); + Objects.requireNonNull(vespaName, "Vespa name cannot be null"); + this.inputMap.add(new OnnxNameMapping(onnxName, vespaName)); + } + + public void addOutputNameMapping(String onnxName, String vespaName) { + Objects.requireNonNull(onnxName, "Onnx name cannot be null"); + Objects.requireNonNull(vespaName, "Vespa name cannot be null"); + this.outputMap.add(new OnnxNameMapping(onnxName, vespaName)); + } + /** Initiate sending of this constant to some services over file distribution */ public void sendTo(Collection<? extends AbstractService> services) { FileReference reference = (pathType == OnnxModel.PathType.FILE) @@ -62,6 +79,9 @@ public class OnnxModel { public String getUri() { return path; } public String getFileReference() { return fileReference; } + public List<OnnxNameMapping> getInputMap() { return Collections.unmodifiableList(inputMap); } + public List<OnnxNameMapping> getOutputMap() { return Collections.unmodifiableList(outputMap); } + public void validate() { if (path == null || path.isEmpty()) throw new IllegalArgumentException("ONNX models must have a file or uri."); @@ -76,4 +96,17 @@ public class OnnxModel { return b.toString(); } + public static class OnnxNameMapping { + private String onnxName; + private String vespaName; + + private OnnxNameMapping(String onnxName, String vespaName) { + this.onnxName = onnxName; + this.vespaName = vespaName; + } + public String getOnnxName() { return onnxName; } + public String getVespaName() { return vespaName; } + public void setVespaName(String vespaName) { this.vespaName = vespaName; } + } + } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java index 87663ac79a3..1cc33664e8c 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java @@ -27,6 +27,10 @@ public class OnnxModels { return models.get(name); } + public boolean has(String name) { + return models.containsKey(name); + } + public Map<String, OnnxModel> asMap() { return Collections.unmodifiableMap(models); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java index 00076c84532..84442fedc48 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java @@ -122,10 +122,14 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ for (OnnxModel model : onnxModels.asMap().values()) { if ("".equals(model.getFileReference())) log.warning("Illegal file reference " + model); // Let tests pass ... we should find a better way - else - builder.model(new OnnxModelsConfig.Model.Builder() - .name(model.getName()) - .fileref(model.getFileReference())); + else { + OnnxModelsConfig.Model.Builder modelBuilder = new OnnxModelsConfig.Model.Builder(); + modelBuilder.name(model.getName()); + modelBuilder.fileref(model.getFileReference()); + model.getInputMap().forEach(mapper -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(mapper.getOnnxName()).source(mapper.getVespaName()))); + model.getOutputMap().forEach(mapper -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(mapper.getOnnxName()).as(mapper.getVespaName()))); + builder.model(modelBuilder); + } } } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java index c3c10139684..87eaaf0387a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java @@ -8,8 +8,10 @@ import com.yahoo.compress.Compressor; import com.yahoo.config.model.api.ModelContext; import com.yahoo.config.model.deploy.TestProperties; import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.searchdefinition.OnnxModel; import com.yahoo.searchdefinition.document.RankType; import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.parser.ParseException; @@ -20,6 +22,7 @@ import com.yahoo.vespa.config.search.RankProfilesConfig; import java.nio.charset.StandardCharsets; import java.util.ArrayList; +import java.util.HashSet; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.LinkedHashSet; @@ -37,10 +40,10 @@ public class RawRankProfile implements RankProfilesConfig.Producer { /** A reusable compressor with default settings */ private static final Compressor compressor = new Compressor(); - + private final String keyEndMarker = "\r="; private final String valueEndMarker = "\r\n"; - + // TODO: These are to expose coupling between the strings used here and elsewhere public final static String summaryFeatureFefPropertyPrefix = "vespa.summary.feature"; public final static String rankFeatureFefPropertyPrefix = "vespa.dump.feature"; @@ -63,7 +66,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { public RawRankProfile(RankProfile rankProfile, QueryProfileRegistry queryProfiles, ImportedMlModels importedModels, AttributeFields attributeFields) { this(rankProfile, queryProfiles, importedModels, attributeFields, new TestProperties()); } - + private Compressor.Compression compress(List<Pair<String, String>> properties) { StringBuilder b = new StringBuilder(); for (Pair<String, String> property : properties) @@ -109,12 +112,12 @@ public class RawRankProfile implements RankProfilesConfig.Producer { b.fef(fefB); } - /** + /** * Returns the properties of this as an unmodifiable list. * Note: This method is expensive. */ public List<Pair<String, String>> configProperties() { return decompress(compressedProperties); } - + private static class Deriver { /** @@ -194,6 +197,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { ignoreDefaultRankFeatures = rankProfile.getIgnoreDefaultRankFeatures(); rankProperties = new ArrayList<>(rankProfile.getRankProperties()); derivePropertiesAndSummaryFeaturesFromFunctions(rankProfile.getFunctions()); + deriveOnnxModelFunctionsAndSummaryFeatures(rankProfile); } private void derivePropertiesAndSummaryFeaturesFromFunctions(Map<String, RankProfile.RankingExpressionFunction> functions) { @@ -433,6 +437,40 @@ public class RawRankProfile implements RankProfilesConfig.Producer { return properties; } + private void deriveOnnxModelFunctionsAndSummaryFeatures(RankProfile rankProfile) { + if (rankProfile.getSearch() == null) return; + if (rankProfile.getSearch().onnxModels().asMap().isEmpty()) return; + replaceOnnxFunctionInputs(rankProfile); + replaceImplicitOnnxConfigSummaryFeatures(rankProfile); + } + + private void replaceOnnxFunctionInputs(RankProfile rankProfile) { + Set<String> functionNames = rankProfile.getFunctions().keySet(); + if (functionNames.isEmpty()) return; + for (OnnxModel onnxModel: rankProfile.getSearch().onnxModels().asMap().values()) { + for (OnnxModel.OnnxNameMapping mapping : onnxModel.getInputMap()) { + String source = mapping.getVespaName(); + if (functionNames.contains(source)) { + mapping.setVespaName("rankingExpression(" + source + ")"); + } + } + } + } + + private void replaceImplicitOnnxConfigSummaryFeatures(RankProfile rankProfile) { + if (summaryFeatures == null || summaryFeatures.isEmpty()) return; + Set<ReferenceNode> replacedSummaryFeatures = new HashSet<>(); + for (Iterator<ReferenceNode> i = summaryFeatures.iterator(); i.hasNext(); ) { + ReferenceNode referenceNode = i.next(); + ReferenceNode replacedNode = OnnxModelTransformer.transformFeature(referenceNode, rankProfile.getSearch()); + if (referenceNode != replacedNode) { + replacedSummaryFeatures.add(replacedNode); + i.remove(); + } + } + summaryFeatures.addAll(replacedSummaryFeatures); + } + } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java index d8ffbd7d030..e1ad003e5bd 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java @@ -1,6 +1,7 @@ // Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.expressiontransforms; +import com.yahoo.searchdefinition.ImmutableSearch; import com.yahoo.searchdefinition.OnnxModel; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; @@ -12,9 +13,8 @@ import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; import java.util.List; /** - * Transforms instances of the onnxModel(model-path, output) ranking feature - * by adding the model file to file distribution and rewriting this feature - * to point to the generated configuration. + * Transforms instances of the onnxModel ranking feature and generates + * ONNX configuration if necessary. * * @author lesters */ @@ -31,27 +31,66 @@ public class OnnxModelTransformer extends ExpressionTransformer<RankProfileTrans } private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) { + if (context.rankProfile() == null) return feature; + if (context.rankProfile().getSearch() == null) return feature; + return transformFeature(feature, context.rankProfile().getSearch()); + } + + public static ReferenceNode transformFeature(ReferenceNode feature, ImmutableSearch search) { if (!feature.getName().equals("onnxModel")) return feature; Arguments arguments = feature.getArguments(); if (arguments.isEmpty()) - throw new IllegalArgumentException("An onnxModel feature must take an argument pointing to the ONNX file."); + throw new IllegalArgumentException("An onnxModel feature must take an argument referring to a " + + "onnx-model config or a ONNX file."); if (arguments.expressions().size() > 2) throw new IllegalArgumentException("An onnxModel feature can have at most 2 arguments."); - String path = asString(arguments.expressions().get(0)); - String name = toModelName(path); - String output = arguments.expressions().size() > 1 ? asString(arguments.expressions().get(1)) : null; - // Validation that the file actually exists is handled when the file is added to file distribution. // Validation of inputs, outputs and corresponding types are currently handled by RankSetupValidator. - // Add model to config - context.rankProfile().getSearch().onnxModels().add(new OnnxModel(name, path)); + String modelConfigName; + OnnxModel onnxModel; + if (arguments.expressions().get(0) instanceof ReferenceNode) { + modelConfigName = arguments.expressions().get(0).toString(); + onnxModel = search.onnxModels().get(modelConfigName); + if (onnxModel == null) { + throw new IllegalArgumentException("onnxModel argument '" + modelConfigName + "' config not found"); + } + } else if (arguments.expressions().get(0) instanceof ConstantNode) { + String path = asString(arguments.expressions().get(0)); + modelConfigName = asValidIdentifier(path); + onnxModel = search.onnxModels().get(modelConfigName); + if (onnxModel == null) { + onnxModel = new OnnxModel(modelConfigName, path); + search.onnxModels().add(onnxModel); + } + } else { + throw new IllegalArgumentException("Illegal argument to onnxModel: '" + arguments.expressions().get(0) + "'"); + } + + String output = null; + if (feature.getOutput() != null) { + output = feature.getOutput(); + if ( ! hasOutputMapping(onnxModel, output)) { + onnxModel.addOutputNameMapping(output, output); + } + } else if (arguments.expressions().size() > 1) { + String name = asString(arguments.expressions().get(1)); + output = asValidIdentifier(name); + if ( ! hasOutputMapping(onnxModel, output)) { + onnxModel.addOutputNameMapping(name, output); + } + } // Replace feature with name of config - ExpressionNode argument = new ReferenceNode(name); + ExpressionNode argument = new ReferenceNode(modelConfigName); return new ReferenceNode("onnxModel", List.of(argument), output); + + } + + private static boolean hasOutputMapping(OnnxModel onnxModel, String as) { + return onnxModel.getOutputMap().stream().anyMatch(m -> m.getVespaName().equals(as)); } private static String asString(ExpressionNode node) { @@ -71,8 +110,8 @@ public class OnnxModelTransformer extends ExpressionTransformer<RankProfileTrans return c == '\'' || c == '"'; } - public static String toModelName(String path) { - return path.replaceAll("[^\\w\\d\\$@_]", "_"); + private static String asValidIdentifier(String str) { + return str.replaceAll("[^\\w\\d\\$@_]", "_"); } } diff --git a/config-model/src/main/javacc/SDParser.jj b/config-model/src/main/javacc/SDParser.jj index ad359a6a943..bf752b39fa8 100644 --- a/config-model/src/main/javacc/SDParser.jj +++ b/config-model/src/main/javacc/SDParser.jj @@ -32,6 +32,7 @@ import com.yahoo.searchdefinition.document.*; import com.yahoo.searchdefinition.document.annotation.SDAnnotationType; import com.yahoo.searchdefinition.document.annotation.TemporaryAnnotationReferenceDataType; import com.yahoo.searchdefinition.RankingConstant; +import com.yahoo.searchdefinition.OnnxModel; import com.yahoo.searchdefinition.Index; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.DocumentsOnlyRankProfile; @@ -231,6 +232,7 @@ TOKEN : | < SUBSTRING: "substring" > | < SUFFIX: "suffix" > | < CONSTANT: "constant"> +| < ONNXMODEL: "onnx-model"> | < RANKPROFILE: "rank-profile" > | < RANKDEGRADATIONFREQ: "rank-degradation-frequency" > | < RANKDEGRADATION: "rank-degradation" > @@ -379,6 +381,8 @@ TOKEN : | < LESSTHAN: "<" > | < GREATERTHAN: ">" > | < VARIABLE: "$" <IDENTIFIER> > +| < ONNX_INPUT_SL: "input" (" ")* (<IDENTIFIER>|<QUOTEDSTRING>) (" ")* ":" (" ")* (~["\n"])* ("\n")? > +| < ONNX_OUTPUT_SL: "output" (" ")* (<IDENTIFIER>|<QUOTEDSTRING>) (" ")* ":" (" ")* (~["\n"])* ("\n")? > } // Declare a special skip token for comments. @@ -451,7 +455,8 @@ Object rootSchemaItem(Search search) : { } | structOutside(search) | annotationOutside(search) | fieldSet(search) - | importField(search) ) + | importField(search) + | onnxModel(search) ) { return null; } } @@ -1847,6 +1852,60 @@ void hnswIndexBody(HnswIndexParams.Builder params) : } /** + * Consumes a onnx-model block of a search element. + * + * @param search The search object to add content to. + */ +void onnxModel(Search search) : +{ + String name; + OnnxModel onnxModel; +} +{ + ( <ONNXMODEL> name = identifier() + { + onnxModel = new OnnxModel(name); + } + lbrace() (onnxModelItem(onnxModel) (<NL>)*)+ <RBRACE> ) + { + if (documentsOnly) return; + search.onnxModels().add(onnxModel); + } +} + +/** + * This rule consumes an onnx-model block. + * + * @param onnxModel The onnxModel to modify. + * @return Null. + */ +Object onnxModelItem(OnnxModel onnxModel) : +{ + String path = null; +} +{ + ( + (<FILE> <COLON> path = filePath() { } (<NL>)*) { onnxModel.setFileName(path); } | + (<URI> <COLON> path = uriPath() { } (<NL>)*) { onnxModel.setUri(path); } | + (<ONNX_INPUT_SL>) { + String name = token.image.substring(5, token.image.lastIndexOf(":")).trim(); + if (name.startsWith("\"")) { name = name.substring(1, name.length() - 1); } + String source = token.image.substring(token.image.lastIndexOf(":") + 1).trim(); + onnxModel.addInputNameMapping(name, source); + } | + (<ONNX_OUTPUT_SL>) { + String name = token.image.substring(6, token.image.lastIndexOf(":")).trim(); + if (name.startsWith("\"")) { name = name.substring(1, name.length() - 1); } + String as = token.image.substring(token.image.lastIndexOf(":") + 1).trim(); + onnxModel.addOutputNameMapping(name, as); + } + ) + { + return null; + } +} + +/** * Consumes a constant block of a search element. * * @param search The search object to add content to. |