diff options
author | Lester Solbakken <lesters@oath.com> | 2020-09-18 14:58:57 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2020-09-18 14:58:57 +0200 |
commit | be544696d4b70ee186dc80f250bda7d99cd0e20f (patch) | |
tree | e200be1130b49b2d626f00a681eec84b4686ab05 | |
parent | bf3ca4359b94aff539fc79b80b4caac66225a028 (diff) |
Add explicit config for onnx models
12 files changed, 329 insertions, 77 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java index b7b18887dd8..c2fb2107604 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java @@ -5,7 +5,10 @@ import com.yahoo.config.FileReference; import com.yahoo.vespa.model.AbstractService; import com.yahoo.vespa.model.utils.FileSender; +import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; +import java.util.List; import java.util.Objects; /** @@ -20,6 +23,8 @@ public class OnnxModel { private final String name; private String path = null; private String fileReference = ""; + private List<OnnxNameMapping> inputMap = new ArrayList<>(); + private List<OnnxNameMapping> outputMap = new ArrayList<>(); public PathType getPathType() { return pathType; @@ -49,6 +54,18 @@ public class OnnxModel { this.pathType = PathType.URI; } + public void addInputNameMapping(String onnxName, String vespaName) { + Objects.requireNonNull(onnxName, "Onnx name cannot be null"); + Objects.requireNonNull(vespaName, "Vespa name cannot be null"); + this.inputMap.add(new OnnxNameMapping(onnxName, vespaName)); + } + + public void addOutputNameMapping(String onnxName, String vespaName) { + Objects.requireNonNull(onnxName, "Onnx name cannot be null"); + Objects.requireNonNull(vespaName, "Vespa name cannot be null"); + this.outputMap.add(new OnnxNameMapping(onnxName, vespaName)); + } + /** Initiate sending of this constant to some services over file distribution */ public void sendTo(Collection<? extends AbstractService> services) { FileReference reference = (pathType == OnnxModel.PathType.FILE) @@ -62,6 +79,9 @@ public class OnnxModel { public String getUri() { return path; } public String getFileReference() { return fileReference; } + public List<OnnxNameMapping> getInputMap() { return Collections.unmodifiableList(inputMap); } + public List<OnnxNameMapping> getOutputMap() { return Collections.unmodifiableList(outputMap); } + public void validate() { if (path == null || path.isEmpty()) throw new IllegalArgumentException("ONNX models must have a file or uri."); @@ -76,4 +96,17 @@ public class OnnxModel { return b.toString(); } + public static class OnnxNameMapping { + private String onnxName; + private String vespaName; + + private OnnxNameMapping(String onnxName, String vespaName) { + this.onnxName = onnxName; + this.vespaName = vespaName; + } + public String getOnnxName() { return onnxName; } + public String getVespaName() { return vespaName; } + public void setVespaName(String vespaName) { this.vespaName = vespaName; } + } + } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java index 87663ac79a3..1cc33664e8c 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java @@ -27,6 +27,10 @@ public class OnnxModels { return models.get(name); } + public boolean has(String name) { + return models.containsKey(name); + } + public Map<String, OnnxModel> asMap() { return Collections.unmodifiableMap(models); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java index 00076c84532..84442fedc48 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java @@ -122,10 +122,14 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ for (OnnxModel model : onnxModels.asMap().values()) { if ("".equals(model.getFileReference())) log.warning("Illegal file reference " + model); // Let tests pass ... we should find a better way - else - builder.model(new OnnxModelsConfig.Model.Builder() - .name(model.getName()) - .fileref(model.getFileReference())); + else { + OnnxModelsConfig.Model.Builder modelBuilder = new OnnxModelsConfig.Model.Builder(); + modelBuilder.name(model.getName()); + modelBuilder.fileref(model.getFileReference()); + model.getInputMap().forEach(mapper -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(mapper.getOnnxName()).source(mapper.getVespaName()))); + model.getOutputMap().forEach(mapper -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(mapper.getOnnxName()).as(mapper.getVespaName()))); + builder.model(modelBuilder); + } } } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java index c3c10139684..87eaaf0387a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java @@ -8,8 +8,10 @@ import com.yahoo.compress.Compressor; import com.yahoo.config.model.api.ModelContext; import com.yahoo.config.model.deploy.TestProperties; import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.searchdefinition.OnnxModel; import com.yahoo.searchdefinition.document.RankType; import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.parser.ParseException; @@ -20,6 +22,7 @@ import com.yahoo.vespa.config.search.RankProfilesConfig; import java.nio.charset.StandardCharsets; import java.util.ArrayList; +import java.util.HashSet; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.LinkedHashSet; @@ -37,10 +40,10 @@ public class RawRankProfile implements RankProfilesConfig.Producer { /** A reusable compressor with default settings */ private static final Compressor compressor = new Compressor(); - + private final String keyEndMarker = "\r="; private final String valueEndMarker = "\r\n"; - + // TODO: These are to expose coupling between the strings used here and elsewhere public final static String summaryFeatureFefPropertyPrefix = "vespa.summary.feature"; public final static String rankFeatureFefPropertyPrefix = "vespa.dump.feature"; @@ -63,7 +66,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { public RawRankProfile(RankProfile rankProfile, QueryProfileRegistry queryProfiles, ImportedMlModels importedModels, AttributeFields attributeFields) { this(rankProfile, queryProfiles, importedModels, attributeFields, new TestProperties()); } - + private Compressor.Compression compress(List<Pair<String, String>> properties) { StringBuilder b = new StringBuilder(); for (Pair<String, String> property : properties) @@ -109,12 +112,12 @@ public class RawRankProfile implements RankProfilesConfig.Producer { b.fef(fefB); } - /** + /** * Returns the properties of this as an unmodifiable list. * Note: This method is expensive. */ public List<Pair<String, String>> configProperties() { return decompress(compressedProperties); } - + private static class Deriver { /** @@ -194,6 +197,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { ignoreDefaultRankFeatures = rankProfile.getIgnoreDefaultRankFeatures(); rankProperties = new ArrayList<>(rankProfile.getRankProperties()); derivePropertiesAndSummaryFeaturesFromFunctions(rankProfile.getFunctions()); + deriveOnnxModelFunctionsAndSummaryFeatures(rankProfile); } private void derivePropertiesAndSummaryFeaturesFromFunctions(Map<String, RankProfile.RankingExpressionFunction> functions) { @@ -433,6 +437,40 @@ public class RawRankProfile implements RankProfilesConfig.Producer { return properties; } + private void deriveOnnxModelFunctionsAndSummaryFeatures(RankProfile rankProfile) { + if (rankProfile.getSearch() == null) return; + if (rankProfile.getSearch().onnxModels().asMap().isEmpty()) return; + replaceOnnxFunctionInputs(rankProfile); + replaceImplicitOnnxConfigSummaryFeatures(rankProfile); + } + + private void replaceOnnxFunctionInputs(RankProfile rankProfile) { + Set<String> functionNames = rankProfile.getFunctions().keySet(); + if (functionNames.isEmpty()) return; + for (OnnxModel onnxModel: rankProfile.getSearch().onnxModels().asMap().values()) { + for (OnnxModel.OnnxNameMapping mapping : onnxModel.getInputMap()) { + String source = mapping.getVespaName(); + if (functionNames.contains(source)) { + mapping.setVespaName("rankingExpression(" + source + ")"); + } + } + } + } + + private void replaceImplicitOnnxConfigSummaryFeatures(RankProfile rankProfile) { + if (summaryFeatures == null || summaryFeatures.isEmpty()) return; + Set<ReferenceNode> replacedSummaryFeatures = new HashSet<>(); + for (Iterator<ReferenceNode> i = summaryFeatures.iterator(); i.hasNext(); ) { + ReferenceNode referenceNode = i.next(); + ReferenceNode replacedNode = OnnxModelTransformer.transformFeature(referenceNode, rankProfile.getSearch()); + if (referenceNode != replacedNode) { + replacedSummaryFeatures.add(replacedNode); + i.remove(); + } + } + summaryFeatures.addAll(replacedSummaryFeatures); + } + } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java index d8ffbd7d030..e1ad003e5bd 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java @@ -1,6 +1,7 @@ // Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.expressiontransforms; +import com.yahoo.searchdefinition.ImmutableSearch; import com.yahoo.searchdefinition.OnnxModel; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; @@ -12,9 +13,8 @@ import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; import java.util.List; /** - * Transforms instances of the onnxModel(model-path, output) ranking feature - * by adding the model file to file distribution and rewriting this feature - * to point to the generated configuration. + * Transforms instances of the onnxModel ranking feature and generates + * ONNX configuration if necessary. * * @author lesters */ @@ -31,27 +31,66 @@ public class OnnxModelTransformer extends ExpressionTransformer<RankProfileTrans } private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) { + if (context.rankProfile() == null) return feature; + if (context.rankProfile().getSearch() == null) return feature; + return transformFeature(feature, context.rankProfile().getSearch()); + } + + public static ReferenceNode transformFeature(ReferenceNode feature, ImmutableSearch search) { if (!feature.getName().equals("onnxModel")) return feature; Arguments arguments = feature.getArguments(); if (arguments.isEmpty()) - throw new IllegalArgumentException("An onnxModel feature must take an argument pointing to the ONNX file."); + throw new IllegalArgumentException("An onnxModel feature must take an argument referring to a " + + "onnx-model config or a ONNX file."); if (arguments.expressions().size() > 2) throw new IllegalArgumentException("An onnxModel feature can have at most 2 arguments."); - String path = asString(arguments.expressions().get(0)); - String name = toModelName(path); - String output = arguments.expressions().size() > 1 ? asString(arguments.expressions().get(1)) : null; - // Validation that the file actually exists is handled when the file is added to file distribution. // Validation of inputs, outputs and corresponding types are currently handled by RankSetupValidator. - // Add model to config - context.rankProfile().getSearch().onnxModels().add(new OnnxModel(name, path)); + String modelConfigName; + OnnxModel onnxModel; + if (arguments.expressions().get(0) instanceof ReferenceNode) { + modelConfigName = arguments.expressions().get(0).toString(); + onnxModel = search.onnxModels().get(modelConfigName); + if (onnxModel == null) { + throw new IllegalArgumentException("onnxModel argument '" + modelConfigName + "' config not found"); + } + } else if (arguments.expressions().get(0) instanceof ConstantNode) { + String path = asString(arguments.expressions().get(0)); + modelConfigName = asValidIdentifier(path); + onnxModel = search.onnxModels().get(modelConfigName); + if (onnxModel == null) { + onnxModel = new OnnxModel(modelConfigName, path); + search.onnxModels().add(onnxModel); + } + } else { + throw new IllegalArgumentException("Illegal argument to onnxModel: '" + arguments.expressions().get(0) + "'"); + } + + String output = null; + if (feature.getOutput() != null) { + output = feature.getOutput(); + if ( ! hasOutputMapping(onnxModel, output)) { + onnxModel.addOutputNameMapping(output, output); + } + } else if (arguments.expressions().size() > 1) { + String name = asString(arguments.expressions().get(1)); + output = asValidIdentifier(name); + if ( ! hasOutputMapping(onnxModel, output)) { + onnxModel.addOutputNameMapping(name, output); + } + } // Replace feature with name of config - ExpressionNode argument = new ReferenceNode(name); + ExpressionNode argument = new ReferenceNode(modelConfigName); return new ReferenceNode("onnxModel", List.of(argument), output); + + } + + private static boolean hasOutputMapping(OnnxModel onnxModel, String as) { + return onnxModel.getOutputMap().stream().anyMatch(m -> m.getVespaName().equals(as)); } private static String asString(ExpressionNode node) { @@ -71,8 +110,8 @@ public class OnnxModelTransformer extends ExpressionTransformer<RankProfileTrans return c == '\'' || c == '"'; } - public static String toModelName(String path) { - return path.replaceAll("[^\\w\\d\\$@_]", "_"); + private static String asValidIdentifier(String str) { + return str.replaceAll("[^\\w\\d\\$@_]", "_"); } } diff --git a/config-model/src/main/javacc/SDParser.jj b/config-model/src/main/javacc/SDParser.jj index ad359a6a943..bf752b39fa8 100644 --- a/config-model/src/main/javacc/SDParser.jj +++ b/config-model/src/main/javacc/SDParser.jj @@ -32,6 +32,7 @@ import com.yahoo.searchdefinition.document.*; import com.yahoo.searchdefinition.document.annotation.SDAnnotationType; import com.yahoo.searchdefinition.document.annotation.TemporaryAnnotationReferenceDataType; import com.yahoo.searchdefinition.RankingConstant; +import com.yahoo.searchdefinition.OnnxModel; import com.yahoo.searchdefinition.Index; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.DocumentsOnlyRankProfile; @@ -231,6 +232,7 @@ TOKEN : | < SUBSTRING: "substring" > | < SUFFIX: "suffix" > | < CONSTANT: "constant"> +| < ONNXMODEL: "onnx-model"> | < RANKPROFILE: "rank-profile" > | < RANKDEGRADATIONFREQ: "rank-degradation-frequency" > | < RANKDEGRADATION: "rank-degradation" > @@ -379,6 +381,8 @@ TOKEN : | < LESSTHAN: "<" > | < GREATERTHAN: ">" > | < VARIABLE: "$" <IDENTIFIER> > +| < ONNX_INPUT_SL: "input" (" ")* (<IDENTIFIER>|<QUOTEDSTRING>) (" ")* ":" (" ")* (~["\n"])* ("\n")? > +| < ONNX_OUTPUT_SL: "output" (" ")* (<IDENTIFIER>|<QUOTEDSTRING>) (" ")* ":" (" ")* (~["\n"])* ("\n")? > } // Declare a special skip token for comments. @@ -451,7 +455,8 @@ Object rootSchemaItem(Search search) : { } | structOutside(search) | annotationOutside(search) | fieldSet(search) - | importField(search) ) + | importField(search) + | onnxModel(search) ) { return null; } } @@ -1847,6 +1852,60 @@ void hnswIndexBody(HnswIndexParams.Builder params) : } /** + * Consumes a onnx-model block of a search element. + * + * @param search The search object to add content to. + */ +void onnxModel(Search search) : +{ + String name; + OnnxModel onnxModel; +} +{ + ( <ONNXMODEL> name = identifier() + { + onnxModel = new OnnxModel(name); + } + lbrace() (onnxModelItem(onnxModel) (<NL>)*)+ <RBRACE> ) + { + if (documentsOnly) return; + search.onnxModels().add(onnxModel); + } +} + +/** + * This rule consumes an onnx-model block. + * + * @param onnxModel The onnxModel to modify. + * @return Null. + */ +Object onnxModelItem(OnnxModel onnxModel) : +{ + String path = null; +} +{ + ( + (<FILE> <COLON> path = filePath() { } (<NL>)*) { onnxModel.setFileName(path); } | + (<URI> <COLON> path = uriPath() { } (<NL>)*) { onnxModel.setUri(path); } | + (<ONNX_INPUT_SL>) { + String name = token.image.substring(5, token.image.lastIndexOf(":")).trim(); + if (name.startsWith("\"")) { name = name.substring(1, name.length() - 1); } + String source = token.image.substring(token.image.lastIndexOf(":") + 1).trim(); + onnxModel.addInputNameMapping(name, source); + } | + (<ONNX_OUTPUT_SL>) { + String name = token.image.substring(6, token.image.lastIndexOf(":")).trim(); + if (name.startsWith("\"")) { name = name.substring(1, name.length() - 1); } + String as = token.image.substring(token.image.lastIndexOf(":") + 1).trim(); + onnxModel.addOutputNameMapping(name, as); + } + ) + { + return null; + } +} + +/** * Consumes a constant block of a search element. * * @param search The search object to add content to. diff --git a/config-model/src/test/integration/onnx-file/files/simple.onnx b/config-model/src/test/integration/onnx-file/files/simple.onnx deleted file mode 100644 index eaa66f533da..00000000000 --- a/config-model/src/test/integration/onnx-file/files/simple.onnx +++ /dev/null @@ -1,23 +0,0 @@ - simple.py:ß -0 -query_tensor -attribute_tensormatmul"MatMul -" -matmul -bias_tensoroutput"Addsimple_scoringZ -query_tensor - - -Z" -attribute_tensor - - -Z -bias_tensor - - -b -output - - -B
\ No newline at end of file diff --git a/config-model/src/test/integration/onnx-file/searchdefinitions/test.sd b/config-model/src/test/integration/onnx-file/searchdefinitions/test.sd deleted file mode 100644 index 5ca0cd1b8bf..00000000000 --- a/config-model/src/test/integration/onnx-file/searchdefinitions/test.sd +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -search test { - document test {} - - rank-profile my_profile inherits default { - first-phase { - expression: onnxModel("files/simple.onnx", "output") - } - } - -} diff --git a/config-model/src/test/integration/onnx-model/files/constant.json b/config-model/src/test/integration/onnx-model/files/constant.json new file mode 100644 index 00000000000..63f64a73af5 --- /dev/null +++ b/config-model/src/test/integration/onnx-model/files/constant.json @@ -0,0 +1,6 @@ +{ + "cells": [ + { "address": { "d0": "0" }, "value": 2.0 }, + { "address": { "d0": "1" }, "value": 3.0 } + ] +}
\ No newline at end of file diff --git a/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd b/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd new file mode 100644 index 00000000000..0f0fa694e6f --- /dev/null +++ b/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd @@ -0,0 +1,70 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +search test { + + document test { + field document_field type tensor(d0[2]) { + indexing: attribute + } + } + + constant my_constant { + file: files/constant.json + type: tensor(d0[2]) + } + + onnx-model my_model { + file: files/ranking_model.onnx + input first_input: attribute(document_field) + input "second/input:0": constant(my_constant) + input "third_input": my_function + output "path/to/output:0": out + } + + onnx-model another_model { + file: files/ranking_model.onnx + input first_input: attribute(document_field) + input "second/input:0": constant(my_constant) + input "third_input": another_function + output "path/to/output:2": out + } + + rank-profile test_model_config { + function my_function() { + expression: tensor(d0[2])(1) + } + first-phase { + expression: onnxModel(my_model).out + } + } + + rank-profile test_generated_model_config inherits test_model_config { + function first_input() { + expression: attribute(document_field) + } + function second_input() { + expression: constant(my_constant) + } + function third_input() { + expression: my_function() + } + first-phase { + expression: onnxModel("files/ranking_model.onnx", "path/to/output:1") + } + } + + rank-profile test_summary_features { + function another_function() { + expression: tensor(d0[2])(2) + } + first-phase { + expression: 1 + } + summary-features { + onnxModel(another_model).out + onnxModel("files/ranking_model.onnx", "path/to/output:2") + } + + } + +} diff --git a/config-model/src/test/integration/onnx-file/services.xml b/config-model/src/test/integration/onnx-model/services.xml index 892ce9a9f89..892ce9a9f89 100644 --- a/config-model/src/test/integration/onnx-file/services.xml +++ b/config-model/src/test/integration/onnx-model/services.xml diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java index 7e129410b37..d9b0c70dfdd 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java @@ -1,7 +1,6 @@ // Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.processing; -import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer; import com.yahoo.vespa.config.search.RankProfilesConfig; import com.yahoo.vespa.config.search.core.OnnxModelsConfig; import com.yahoo.vespa.model.VespaModel; @@ -16,36 +15,70 @@ public class RankingExpressionWithOnnxModelTestCase { @Test public void testOnnxModelFeature() { - VespaModel model = new VespaModelCreatorWithFilePkg("src/test/integration/onnx-file").create(); + VespaModel model = new VespaModelCreatorWithFilePkg("src/test/integration/onnx-model").create(); DocumentDatabase db = ((IndexedSearchCluster)model.getSearchClusters().get(0)).getDocumentDbs().get(0); - - String modelName = OnnxModelTransformer.toModelName("files/simple.onnx"); - - // Ranking expression should be transformed from - // onnxModel("files/simple.onnx", "output") - // to - // onnxModel(files_simple_onnx).output - - assertTransformedFeature(db, modelName); - assertGeneratedConfig(db, modelName); + assertTransformedFeature(db); + assertGeneratedConfig(db); } - private void assertGeneratedConfig(DocumentDatabase db, String modelName) { + private void assertGeneratedConfig(DocumentDatabase db) { OnnxModelsConfig.Builder builder = new OnnxModelsConfig.Builder(); ((OnnxModelsConfig.Producer) db).getConfig(builder); OnnxModelsConfig config = new OnnxModelsConfig(builder); - assertEquals(1, config.model().size()); - assertEquals(modelName, config.model(0).name()); + assertEquals(3, config.model().size()); + + assertEquals("my_model", config.model(1).name()); + assertEquals(3, config.model(1).input().size()); + assertEquals("first_input", config.model(1).input(0).name()); + assertEquals("attribute(document_field)", config.model(1).input(0).source()); + assertEquals("second/input:0", config.model(1).input(1).name()); + assertEquals("constant(my_constant)", config.model(1).input(1).source()); + assertEquals("third_input", config.model(1).input(2).name()); + assertEquals("rankingExpression(my_function)", config.model(1).input(2).source()); + assertEquals(1, config.model(1).output().size()); + assertEquals("path/to/output:0", config.model(1).output(0).name()); + assertEquals("out", config.model(1).output(0).as()); + + assertEquals("files_ranking_model_onnx", config.model(0).name()); + assertEquals(0, config.model(0).input().size()); + assertEquals(2, config.model(0).output().size()); + assertEquals("path/to/output:1", config.model(0).output(0).name()); + assertEquals("path_to_output_1", config.model(0).output(0).as()); + assertEquals("path/to/output:2", config.model(0).output(1).name()); + assertEquals("path_to_output_2", config.model(0).output(1).as()); + + assertEquals("another_model", config.model(2).name()); + assertEquals("third_input", config.model(2).input(2).name()); + assertEquals("rankingExpression(another_function)", config.model(2).input(2).source()); } - private void assertTransformedFeature(DocumentDatabase db, String modelName) { + private void assertTransformedFeature(DocumentDatabase db) { RankProfilesConfig.Builder builder = new RankProfilesConfig.Builder(); ((RankProfilesConfig.Producer) db).getConfig(builder); RankProfilesConfig config = new RankProfilesConfig(builder); - assertEquals(3, config.rankprofile().size()); - assertEquals("my_profile", config.rankprofile(2).name()); - assertEquals("vespa.rank.firstphase", config.rankprofile(2).fef().property(0).name()); - assertEquals("onnxModel(" + modelName + ").output", config.rankprofile(2).fef().property(0).value()); + assertEquals(5, config.rankprofile().size()); + + assertEquals("test_model_config", config.rankprofile(2).name()); + assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(2).fef().property(0).name()); + assertEquals("vespa.rank.firstphase", config.rankprofile(2).fef().property(2).name()); + assertEquals("onnxModel(my_model).out", config.rankprofile(2).fef().property(2).value()); + + assertEquals("test_generated_model_config", config.rankprofile(3).name()); + assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(3).fef().property(0).name()); + assertEquals("rankingExpression(first_input).rankingScript", config.rankprofile(3).fef().property(2).name()); + assertEquals("rankingExpression(second_input).rankingScript", config.rankprofile(3).fef().property(4).name()); + assertEquals("rankingExpression(third_input).rankingScript", config.rankprofile(3).fef().property(6).name()); + assertEquals("vespa.rank.firstphase", config.rankprofile(3).fef().property(8).name()); + assertEquals("onnxModel(files_ranking_model_onnx).path_to_output_1", config.rankprofile(3).fef().property(8).value()); + + assertEquals("test_summary_features", config.rankprofile(4).name()); + assertEquals("rankingExpression(another_function).rankingScript", config.rankprofile(4).fef().property(0).name()); + assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(4).fef().property(3).name()); + assertEquals("1", config.rankprofile(4).fef().property(3).value()); + assertEquals("vespa.summary.feature", config.rankprofile(4).fef().property(4).name()); + assertEquals("onnxModel(another_model).out", config.rankprofile(4).fef().property(4).value()); + assertEquals("vespa.summary.feature", config.rankprofile(4).fef().property(5).name()); + assertEquals("onnxModel(files_ranking_model_onnx).path_to_output_2", config.rankprofile(4).fef().property(5).value()); } } |