summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2020-09-20 11:08:33 +0200
committerGitHub <noreply@github.com>2020-09-20 11:08:33 +0200
commit2c193d74d00dd3c3fa90b347ec77fcea828cce2f (patch)
tree20d98bb2ed1c7ad91e753eedfe91fffd7d4850f0
parenta8c10d0114c7157a34b82776d6c45aaf3e440147 (diff)
parentbe544696d4b70ee186dc80f250bda7d99cd0e20f (diff)
Merge pull request #14450 from vespa-engine/lesters/explicit-onnx-config
Add explicit config for onnx models
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java33
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java4
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java12
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java48
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java65
-rw-r--r--config-model/src/main/javacc/SDParser.jj61
-rw-r--r--config-model/src/test/integration/onnx-file/files/simple.onnx23
-rw-r--r--config-model/src/test/integration/onnx-file/searchdefinitions/test.sd11
-rw-r--r--config-model/src/test/integration/onnx-model/files/constant.json6
-rw-r--r--config-model/src/test/integration/onnx-model/searchdefinitions/test.sd70
-rw-r--r--config-model/src/test/integration/onnx-model/services.xml (renamed from config-model/src/test/integration/onnx-file/services.xml)0
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java73
12 files changed, 329 insertions, 77 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
index b7b18887dd8..c2fb2107604 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
@@ -5,7 +5,10 @@ import com.yahoo.config.FileReference;
import com.yahoo.vespa.model.AbstractService;
import com.yahoo.vespa.model.utils.FileSender;
+import java.util.ArrayList;
import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
import java.util.Objects;
/**
@@ -20,6 +23,8 @@ public class OnnxModel {
private final String name;
private String path = null;
private String fileReference = "";
+ private List<OnnxNameMapping> inputMap = new ArrayList<>();
+ private List<OnnxNameMapping> outputMap = new ArrayList<>();
public PathType getPathType() {
return pathType;
@@ -49,6 +54,18 @@ public class OnnxModel {
this.pathType = PathType.URI;
}
+ public void addInputNameMapping(String onnxName, String vespaName) {
+ Objects.requireNonNull(onnxName, "Onnx name cannot be null");
+ Objects.requireNonNull(vespaName, "Vespa name cannot be null");
+ this.inputMap.add(new OnnxNameMapping(onnxName, vespaName));
+ }
+
+ public void addOutputNameMapping(String onnxName, String vespaName) {
+ Objects.requireNonNull(onnxName, "Onnx name cannot be null");
+ Objects.requireNonNull(vespaName, "Vespa name cannot be null");
+ this.outputMap.add(new OnnxNameMapping(onnxName, vespaName));
+ }
+
/** Initiate sending of this constant to some services over file distribution */
public void sendTo(Collection<? extends AbstractService> services) {
FileReference reference = (pathType == OnnxModel.PathType.FILE)
@@ -62,6 +79,9 @@ public class OnnxModel {
public String getUri() { return path; }
public String getFileReference() { return fileReference; }
+ public List<OnnxNameMapping> getInputMap() { return Collections.unmodifiableList(inputMap); }
+ public List<OnnxNameMapping> getOutputMap() { return Collections.unmodifiableList(outputMap); }
+
public void validate() {
if (path == null || path.isEmpty())
throw new IllegalArgumentException("ONNX models must have a file or uri.");
@@ -76,4 +96,17 @@ public class OnnxModel {
return b.toString();
}
+ public static class OnnxNameMapping {
+ private String onnxName;
+ private String vespaName;
+
+ private OnnxNameMapping(String onnxName, String vespaName) {
+ this.onnxName = onnxName;
+ this.vespaName = vespaName;
+ }
+ public String getOnnxName() { return onnxName; }
+ public String getVespaName() { return vespaName; }
+ public void setVespaName(String vespaName) { this.vespaName = vespaName; }
+ }
+
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java
index 87663ac79a3..1cc33664e8c 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java
@@ -27,6 +27,10 @@ public class OnnxModels {
return models.get(name);
}
+ public boolean has(String name) {
+ return models.containsKey(name);
+ }
+
public Map<String, OnnxModel> asMap() {
return Collections.unmodifiableMap(models);
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
index 00076c84532..84442fedc48 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
@@ -122,10 +122,14 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
for (OnnxModel model : onnxModels.asMap().values()) {
if ("".equals(model.getFileReference()))
log.warning("Illegal file reference " + model); // Let tests pass ... we should find a better way
- else
- builder.model(new OnnxModelsConfig.Model.Builder()
- .name(model.getName())
- .fileref(model.getFileReference()));
+ else {
+ OnnxModelsConfig.Model.Builder modelBuilder = new OnnxModelsConfig.Model.Builder();
+ modelBuilder.name(model.getName());
+ modelBuilder.fileref(model.getFileReference());
+ model.getInputMap().forEach(mapper -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(mapper.getOnnxName()).source(mapper.getVespaName())));
+ model.getOutputMap().forEach(mapper -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(mapper.getOnnxName()).as(mapper.getVespaName())));
+ builder.model(modelBuilder);
+ }
}
}
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
index c3c10139684..87eaaf0387a 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
@@ -8,8 +8,10 @@ import com.yahoo.compress.Compressor;
import com.yahoo.config.model.api.ModelContext;
import com.yahoo.config.model.deploy.TestProperties;
import com.yahoo.search.query.profile.QueryProfileRegistry;
+import com.yahoo.searchdefinition.OnnxModel;
import com.yahoo.searchdefinition.document.RankType;
import com.yahoo.searchdefinition.RankProfile;
+import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
@@ -20,6 +22,7 @@ import com.yahoo.vespa.config.search.RankProfilesConfig;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
+import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
@@ -37,10 +40,10 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
/** A reusable compressor with default settings */
private static final Compressor compressor = new Compressor();
-
+
private final String keyEndMarker = "\r=";
private final String valueEndMarker = "\r\n";
-
+
// TODO: These are to expose coupling between the strings used here and elsewhere
public final static String summaryFeatureFefPropertyPrefix = "vespa.summary.feature";
public final static String rankFeatureFefPropertyPrefix = "vespa.dump.feature";
@@ -63,7 +66,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
public RawRankProfile(RankProfile rankProfile, QueryProfileRegistry queryProfiles, ImportedMlModels importedModels, AttributeFields attributeFields) {
this(rankProfile, queryProfiles, importedModels, attributeFields, new TestProperties());
}
-
+
private Compressor.Compression compress(List<Pair<String, String>> properties) {
StringBuilder b = new StringBuilder();
for (Pair<String, String> property : properties)
@@ -109,12 +112,12 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
b.fef(fefB);
}
- /**
+ /**
* Returns the properties of this as an unmodifiable list.
* Note: This method is expensive.
*/
public List<Pair<String, String>> configProperties() { return decompress(compressedProperties); }
-
+
private static class Deriver {
/**
@@ -194,6 +197,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
ignoreDefaultRankFeatures = rankProfile.getIgnoreDefaultRankFeatures();
rankProperties = new ArrayList<>(rankProfile.getRankProperties());
derivePropertiesAndSummaryFeaturesFromFunctions(rankProfile.getFunctions());
+ deriveOnnxModelFunctionsAndSummaryFeatures(rankProfile);
}
private void derivePropertiesAndSummaryFeaturesFromFunctions(Map<String, RankProfile.RankingExpressionFunction> functions) {
@@ -433,6 +437,40 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
return properties;
}
+ private void deriveOnnxModelFunctionsAndSummaryFeatures(RankProfile rankProfile) {
+ if (rankProfile.getSearch() == null) return;
+ if (rankProfile.getSearch().onnxModels().asMap().isEmpty()) return;
+ replaceOnnxFunctionInputs(rankProfile);
+ replaceImplicitOnnxConfigSummaryFeatures(rankProfile);
+ }
+
+ private void replaceOnnxFunctionInputs(RankProfile rankProfile) {
+ Set<String> functionNames = rankProfile.getFunctions().keySet();
+ if (functionNames.isEmpty()) return;
+ for (OnnxModel onnxModel: rankProfile.getSearch().onnxModels().asMap().values()) {
+ for (OnnxModel.OnnxNameMapping mapping : onnxModel.getInputMap()) {
+ String source = mapping.getVespaName();
+ if (functionNames.contains(source)) {
+ mapping.setVespaName("rankingExpression(" + source + ")");
+ }
+ }
+ }
+ }
+
+ private void replaceImplicitOnnxConfigSummaryFeatures(RankProfile rankProfile) {
+ if (summaryFeatures == null || summaryFeatures.isEmpty()) return;
+ Set<ReferenceNode> replacedSummaryFeatures = new HashSet<>();
+ for (Iterator<ReferenceNode> i = summaryFeatures.iterator(); i.hasNext(); ) {
+ ReferenceNode referenceNode = i.next();
+ ReferenceNode replacedNode = OnnxModelTransformer.transformFeature(referenceNode, rankProfile.getSearch());
+ if (referenceNode != replacedNode) {
+ replacedSummaryFeatures.add(replacedNode);
+ i.remove();
+ }
+ }
+ summaryFeatures.addAll(replacedSummaryFeatures);
+ }
+
}
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
index d8ffbd7d030..e1ad003e5bd 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
@@ -1,6 +1,7 @@
// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition.expressiontransforms;
+import com.yahoo.searchdefinition.ImmutableSearch;
import com.yahoo.searchdefinition.OnnxModel;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
@@ -12,9 +13,8 @@ import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
import java.util.List;
/**
- * Transforms instances of the onnxModel(model-path, output) ranking feature
- * by adding the model file to file distribution and rewriting this feature
- * to point to the generated configuration.
+ * Transforms instances of the onnxModel ranking feature and generates
+ * ONNX configuration if necessary.
*
* @author lesters
*/
@@ -31,27 +31,66 @@ public class OnnxModelTransformer extends ExpressionTransformer<RankProfileTrans
}
private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) {
+ if (context.rankProfile() == null) return feature;
+ if (context.rankProfile().getSearch() == null) return feature;
+ return transformFeature(feature, context.rankProfile().getSearch());
+ }
+
+ public static ReferenceNode transformFeature(ReferenceNode feature, ImmutableSearch search) {
if (!feature.getName().equals("onnxModel")) return feature;
Arguments arguments = feature.getArguments();
if (arguments.isEmpty())
- throw new IllegalArgumentException("An onnxModel feature must take an argument pointing to the ONNX file.");
+ throw new IllegalArgumentException("An onnxModel feature must take an argument referring to a " +
+ "onnx-model config or a ONNX file.");
if (arguments.expressions().size() > 2)
throw new IllegalArgumentException("An onnxModel feature can have at most 2 arguments.");
- String path = asString(arguments.expressions().get(0));
- String name = toModelName(path);
- String output = arguments.expressions().size() > 1 ? asString(arguments.expressions().get(1)) : null;
-
// Validation that the file actually exists is handled when the file is added to file distribution.
// Validation of inputs, outputs and corresponding types are currently handled by RankSetupValidator.
- // Add model to config
- context.rankProfile().getSearch().onnxModels().add(new OnnxModel(name, path));
+ String modelConfigName;
+ OnnxModel onnxModel;
+ if (arguments.expressions().get(0) instanceof ReferenceNode) {
+ modelConfigName = arguments.expressions().get(0).toString();
+ onnxModel = search.onnxModels().get(modelConfigName);
+ if (onnxModel == null) {
+ throw new IllegalArgumentException("onnxModel argument '" + modelConfigName + "' config not found");
+ }
+ } else if (arguments.expressions().get(0) instanceof ConstantNode) {
+ String path = asString(arguments.expressions().get(0));
+ modelConfigName = asValidIdentifier(path);
+ onnxModel = search.onnxModels().get(modelConfigName);
+ if (onnxModel == null) {
+ onnxModel = new OnnxModel(modelConfigName, path);
+ search.onnxModels().add(onnxModel);
+ }
+ } else {
+ throw new IllegalArgumentException("Illegal argument to onnxModel: '" + arguments.expressions().get(0) + "'");
+ }
+
+ String output = null;
+ if (feature.getOutput() != null) {
+ output = feature.getOutput();
+ if ( ! hasOutputMapping(onnxModel, output)) {
+ onnxModel.addOutputNameMapping(output, output);
+ }
+ } else if (arguments.expressions().size() > 1) {
+ String name = asString(arguments.expressions().get(1));
+ output = asValidIdentifier(name);
+ if ( ! hasOutputMapping(onnxModel, output)) {
+ onnxModel.addOutputNameMapping(name, output);
+ }
+ }
// Replace feature with name of config
- ExpressionNode argument = new ReferenceNode(name);
+ ExpressionNode argument = new ReferenceNode(modelConfigName);
return new ReferenceNode("onnxModel", List.of(argument), output);
+
+ }
+
+ private static boolean hasOutputMapping(OnnxModel onnxModel, String as) {
+ return onnxModel.getOutputMap().stream().anyMatch(m -> m.getVespaName().equals(as));
}
private static String asString(ExpressionNode node) {
@@ -71,8 +110,8 @@ public class OnnxModelTransformer extends ExpressionTransformer<RankProfileTrans
return c == '\'' || c == '"';
}
- public static String toModelName(String path) {
- return path.replaceAll("[^\\w\\d\\$@_]", "_");
+ private static String asValidIdentifier(String str) {
+ return str.replaceAll("[^\\w\\d\\$@_]", "_");
}
}
diff --git a/config-model/src/main/javacc/SDParser.jj b/config-model/src/main/javacc/SDParser.jj
index ad359a6a943..bf752b39fa8 100644
--- a/config-model/src/main/javacc/SDParser.jj
+++ b/config-model/src/main/javacc/SDParser.jj
@@ -32,6 +32,7 @@ import com.yahoo.searchdefinition.document.*;
import com.yahoo.searchdefinition.document.annotation.SDAnnotationType;
import com.yahoo.searchdefinition.document.annotation.TemporaryAnnotationReferenceDataType;
import com.yahoo.searchdefinition.RankingConstant;
+import com.yahoo.searchdefinition.OnnxModel;
import com.yahoo.searchdefinition.Index;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.DocumentsOnlyRankProfile;
@@ -231,6 +232,7 @@ TOKEN :
| < SUBSTRING: "substring" >
| < SUFFIX: "suffix" >
| < CONSTANT: "constant">
+| < ONNXMODEL: "onnx-model">
| < RANKPROFILE: "rank-profile" >
| < RANKDEGRADATIONFREQ: "rank-degradation-frequency" >
| < RANKDEGRADATION: "rank-degradation" >
@@ -379,6 +381,8 @@ TOKEN :
| < LESSTHAN: "<" >
| < GREATERTHAN: ">" >
| < VARIABLE: "$" <IDENTIFIER> >
+| < ONNX_INPUT_SL: "input" (" ")* (<IDENTIFIER>|<QUOTEDSTRING>) (" ")* ":" (" ")* (~["\n"])* ("\n")? >
+| < ONNX_OUTPUT_SL: "output" (" ")* (<IDENTIFIER>|<QUOTEDSTRING>) (" ")* ":" (" ")* (~["\n"])* ("\n")? >
}
// Declare a special skip token for comments.
@@ -451,7 +455,8 @@ Object rootSchemaItem(Search search) : { }
| structOutside(search)
| annotationOutside(search)
| fieldSet(search)
- | importField(search) )
+ | importField(search)
+ | onnxModel(search) )
{ return null; }
}
@@ -1847,6 +1852,60 @@ void hnswIndexBody(HnswIndexParams.Builder params) :
}
/**
+ * Consumes a onnx-model block of a search element.
+ *
+ * @param search The search object to add content to.
+ */
+void onnxModel(Search search) :
+{
+ String name;
+ OnnxModel onnxModel;
+}
+{
+ ( <ONNXMODEL> name = identifier()
+ {
+ onnxModel = new OnnxModel(name);
+ }
+ lbrace() (onnxModelItem(onnxModel) (<NL>)*)+ <RBRACE> )
+ {
+ if (documentsOnly) return;
+ search.onnxModels().add(onnxModel);
+ }
+}
+
+/**
+ * This rule consumes an onnx-model block.
+ *
+ * @param onnxModel The onnxModel to modify.
+ * @return Null.
+ */
+Object onnxModelItem(OnnxModel onnxModel) :
+{
+ String path = null;
+}
+{
+ (
+ (<FILE> <COLON> path = filePath() { } (<NL>)*) { onnxModel.setFileName(path); } |
+ (<URI> <COLON> path = uriPath() { } (<NL>)*) { onnxModel.setUri(path); } |
+ (<ONNX_INPUT_SL>) {
+ String name = token.image.substring(5, token.image.lastIndexOf(":")).trim();
+ if (name.startsWith("\"")) { name = name.substring(1, name.length() - 1); }
+ String source = token.image.substring(token.image.lastIndexOf(":") + 1).trim();
+ onnxModel.addInputNameMapping(name, source);
+ } |
+ (<ONNX_OUTPUT_SL>) {
+ String name = token.image.substring(6, token.image.lastIndexOf(":")).trim();
+ if (name.startsWith("\"")) { name = name.substring(1, name.length() - 1); }
+ String as = token.image.substring(token.image.lastIndexOf(":") + 1).trim();
+ onnxModel.addOutputNameMapping(name, as);
+ }
+ )
+ {
+ return null;
+ }
+}
+
+/**
* Consumes a constant block of a search element.
*
* @param search The search object to add content to.
diff --git a/config-model/src/test/integration/onnx-file/files/simple.onnx b/config-model/src/test/integration/onnx-file/files/simple.onnx
deleted file mode 100644
index eaa66f533da..00000000000
--- a/config-model/src/test/integration/onnx-file/files/simple.onnx
+++ /dev/null
@@ -1,23 +0,0 @@
- simple.py:ß
-0
- query_tensor
-attribute_tensormatmul"MatMul
-"
-matmul
- bias_tensoroutput"Addsimple_scoringZ
- query_tensor
- 
-
-Z"
-attribute_tensor
- 
-
-Z
- bias_tensor
-
-
-b
-output
- 
-
-B \ No newline at end of file
diff --git a/config-model/src/test/integration/onnx-file/searchdefinitions/test.sd b/config-model/src/test/integration/onnx-file/searchdefinitions/test.sd
deleted file mode 100644
index 5ca0cd1b8bf..00000000000
--- a/config-model/src/test/integration/onnx-file/searchdefinitions/test.sd
+++ /dev/null
@@ -1,11 +0,0 @@
-# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-search test {
- document test {}
-
- rank-profile my_profile inherits default {
- first-phase {
- expression: onnxModel("files/simple.onnx", "output")
- }
- }
-
-}
diff --git a/config-model/src/test/integration/onnx-model/files/constant.json b/config-model/src/test/integration/onnx-model/files/constant.json
new file mode 100644
index 00000000000..63f64a73af5
--- /dev/null
+++ b/config-model/src/test/integration/onnx-model/files/constant.json
@@ -0,0 +1,6 @@
+{
+ "cells": [
+ { "address": { "d0": "0" }, "value": 2.0 },
+ { "address": { "d0": "1" }, "value": 3.0 }
+ ]
+} \ No newline at end of file
diff --git a/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd b/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd
new file mode 100644
index 00000000000..0f0fa694e6f
--- /dev/null
+++ b/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd
@@ -0,0 +1,70 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+search test {
+
+ document test {
+ field document_field type tensor(d0[2]) {
+ indexing: attribute
+ }
+ }
+
+ constant my_constant {
+ file: files/constant.json
+ type: tensor(d0[2])
+ }
+
+ onnx-model my_model {
+ file: files/ranking_model.onnx
+ input first_input: attribute(document_field)
+ input "second/input:0": constant(my_constant)
+ input "third_input": my_function
+ output "path/to/output:0": out
+ }
+
+ onnx-model another_model {
+ file: files/ranking_model.onnx
+ input first_input: attribute(document_field)
+ input "second/input:0": constant(my_constant)
+ input "third_input": another_function
+ output "path/to/output:2": out
+ }
+
+ rank-profile test_model_config {
+ function my_function() {
+ expression: tensor(d0[2])(1)
+ }
+ first-phase {
+ expression: onnxModel(my_model).out
+ }
+ }
+
+ rank-profile test_generated_model_config inherits test_model_config {
+ function first_input() {
+ expression: attribute(document_field)
+ }
+ function second_input() {
+ expression: constant(my_constant)
+ }
+ function third_input() {
+ expression: my_function()
+ }
+ first-phase {
+ expression: onnxModel("files/ranking_model.onnx", "path/to/output:1")
+ }
+ }
+
+ rank-profile test_summary_features {
+ function another_function() {
+ expression: tensor(d0[2])(2)
+ }
+ first-phase {
+ expression: 1
+ }
+ summary-features {
+ onnxModel(another_model).out
+ onnxModel("files/ranking_model.onnx", "path/to/output:2")
+ }
+
+ }
+
+}
diff --git a/config-model/src/test/integration/onnx-file/services.xml b/config-model/src/test/integration/onnx-model/services.xml
index 892ce9a9f89..892ce9a9f89 100644
--- a/config-model/src/test/integration/onnx-file/services.xml
+++ b/config-model/src/test/integration/onnx-model/services.xml
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
index 7e129410b37..d9b0c70dfdd 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
@@ -1,7 +1,6 @@
// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition.processing;
-import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.model.VespaModel;
@@ -16,36 +15,70 @@ public class RankingExpressionWithOnnxModelTestCase {
@Test
public void testOnnxModelFeature() {
- VespaModel model = new VespaModelCreatorWithFilePkg("src/test/integration/onnx-file").create();
+ VespaModel model = new VespaModelCreatorWithFilePkg("src/test/integration/onnx-model").create();
DocumentDatabase db = ((IndexedSearchCluster)model.getSearchClusters().get(0)).getDocumentDbs().get(0);
-
- String modelName = OnnxModelTransformer.toModelName("files/simple.onnx");
-
- // Ranking expression should be transformed from
- // onnxModel("files/simple.onnx", "output")
- // to
- // onnxModel(files_simple_onnx).output
-
- assertTransformedFeature(db, modelName);
- assertGeneratedConfig(db, modelName);
+ assertTransformedFeature(db);
+ assertGeneratedConfig(db);
}
- private void assertGeneratedConfig(DocumentDatabase db, String modelName) {
+ private void assertGeneratedConfig(DocumentDatabase db) {
OnnxModelsConfig.Builder builder = new OnnxModelsConfig.Builder();
((OnnxModelsConfig.Producer) db).getConfig(builder);
OnnxModelsConfig config = new OnnxModelsConfig(builder);
- assertEquals(1, config.model().size());
- assertEquals(modelName, config.model(0).name());
+ assertEquals(3, config.model().size());
+
+ assertEquals("my_model", config.model(1).name());
+ assertEquals(3, config.model(1).input().size());
+ assertEquals("first_input", config.model(1).input(0).name());
+ assertEquals("attribute(document_field)", config.model(1).input(0).source());
+ assertEquals("second/input:0", config.model(1).input(1).name());
+ assertEquals("constant(my_constant)", config.model(1).input(1).source());
+ assertEquals("third_input", config.model(1).input(2).name());
+ assertEquals("rankingExpression(my_function)", config.model(1).input(2).source());
+ assertEquals(1, config.model(1).output().size());
+ assertEquals("path/to/output:0", config.model(1).output(0).name());
+ assertEquals("out", config.model(1).output(0).as());
+
+ assertEquals("files_ranking_model_onnx", config.model(0).name());
+ assertEquals(0, config.model(0).input().size());
+ assertEquals(2, config.model(0).output().size());
+ assertEquals("path/to/output:1", config.model(0).output(0).name());
+ assertEquals("path_to_output_1", config.model(0).output(0).as());
+ assertEquals("path/to/output:2", config.model(0).output(1).name());
+ assertEquals("path_to_output_2", config.model(0).output(1).as());
+
+ assertEquals("another_model", config.model(2).name());
+ assertEquals("third_input", config.model(2).input(2).name());
+ assertEquals("rankingExpression(another_function)", config.model(2).input(2).source());
}
- private void assertTransformedFeature(DocumentDatabase db, String modelName) {
+ private void assertTransformedFeature(DocumentDatabase db) {
RankProfilesConfig.Builder builder = new RankProfilesConfig.Builder();
((RankProfilesConfig.Producer) db).getConfig(builder);
RankProfilesConfig config = new RankProfilesConfig(builder);
- assertEquals(3, config.rankprofile().size());
- assertEquals("my_profile", config.rankprofile(2).name());
- assertEquals("vespa.rank.firstphase", config.rankprofile(2).fef().property(0).name());
- assertEquals("onnxModel(" + modelName + ").output", config.rankprofile(2).fef().property(0).value());
+ assertEquals(5, config.rankprofile().size());
+
+ assertEquals("test_model_config", config.rankprofile(2).name());
+ assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(2).fef().property(0).name());
+ assertEquals("vespa.rank.firstphase", config.rankprofile(2).fef().property(2).name());
+ assertEquals("onnxModel(my_model).out", config.rankprofile(2).fef().property(2).value());
+
+ assertEquals("test_generated_model_config", config.rankprofile(3).name());
+ assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(3).fef().property(0).name());
+ assertEquals("rankingExpression(first_input).rankingScript", config.rankprofile(3).fef().property(2).name());
+ assertEquals("rankingExpression(second_input).rankingScript", config.rankprofile(3).fef().property(4).name());
+ assertEquals("rankingExpression(third_input).rankingScript", config.rankprofile(3).fef().property(6).name());
+ assertEquals("vespa.rank.firstphase", config.rankprofile(3).fef().property(8).name());
+ assertEquals("onnxModel(files_ranking_model_onnx).path_to_output_1", config.rankprofile(3).fef().property(8).value());
+
+ assertEquals("test_summary_features", config.rankprofile(4).name());
+ assertEquals("rankingExpression(another_function).rankingScript", config.rankprofile(4).fef().property(0).name());
+ assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(4).fef().property(3).name());
+ assertEquals("1", config.rankprofile(4).fef().property(3).value());
+ assertEquals("vespa.summary.feature", config.rankprofile(4).fef().property(4).name());
+ assertEquals("onnxModel(another_model).out", config.rankprofile(4).fef().property(4).value());
+ assertEquals("vespa.summary.feature", config.rankprofile(4).fef().property(5).name());
+ assertEquals("onnxModel(files_ranking_model_onnx).path_to_output_2", config.rankprofile(4).fef().property(5).value());
}
}