summaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java141
1 files changed, 59 insertions, 82 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
index 69cdae10e47..e1ad003e5bd 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
@@ -1,36 +1,20 @@
// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition.expressiontransforms;
-import com.yahoo.path.Path;
import com.yahoo.searchdefinition.ImmutableSearch;
import com.yahoo.searchdefinition.OnnxModel;
-import com.yahoo.searchdefinition.RankProfile;
-import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
-import com.yahoo.vespa.model.ml.ConvertedModel;
-import com.yahoo.vespa.model.ml.FeatureArguments;
-import com.yahoo.vespa.model.ml.ModelName;
import java.util.List;
/**
- * Transforms ONNX model features of the forms:
- *
- * onnxModel(config_name)
- * onnxModel(config_name).output
- * onnxModel("path/to/model")
- * onnxModel("path/to/model").output
- * onnxModel("path/to/model", "path/to/output")
- * onnxModel("path/to/model", "unused", "path/to/output") // signature is unused
- *
- * To the format expected by the backend:
- *
- * onnxModel(config_name).output
+ * Transforms instances of the onnxModel ranking feature and generates
+ * ONNX configuration if necessary.
*
* @author lesters
*/
@@ -49,92 +33,85 @@ public class OnnxModelTransformer extends ExpressionTransformer<RankProfileTrans
private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) {
if (context.rankProfile() == null) return feature;
if (context.rankProfile().getSearch() == null) return feature;
- return transformFeature(feature, context.rankProfile());
+ return transformFeature(feature, context.rankProfile().getSearch());
}
- public static ExpressionNode transformFeature(ReferenceNode feature, RankProfile rankProfile) {
- ImmutableSearch search = rankProfile.getSearch();
- final String featureName = feature.getName();
- if ( ! featureName.equals("onnxModel")) return feature;
+ public static ReferenceNode transformFeature(ReferenceNode feature, ImmutableSearch search) {
+ if (!feature.getName().equals("onnxModel")) return feature;
Arguments arguments = feature.getArguments();
if (arguments.isEmpty())
- throw new IllegalArgumentException("An " + featureName + " feature must take an argument referring to a " +
- "onnx-model config or an ONNX file.");
- if (arguments.expressions().size() > 3)
- throw new IllegalArgumentException("An " + featureName + " feature can have at most 3 arguments.");
-
- // Check that the model configuration "onnx-model" exists. If not defined, it should have been added
- // by the "OnnxModelConfigGenerator" processor. If it still doesn't exist, it is because we can't find
- // the actual ONNX file, which can happen if we are restarting or upgrading an application using an
- // ONNX file that was transformed to Vespa ranking expressions. We then assume it is in the model store.
-
- String modelConfigName = getModelConfigName(feature.reference());
- OnnxModel onnxModel = search.onnxModels().get(modelConfigName);
- if (onnxModel == null) {
+ throw new IllegalArgumentException("An onnxModel feature must take an argument referring to a " +
+ "onnx-model config or a ONNX file.");
+ if (arguments.expressions().size() > 2)
+ throw new IllegalArgumentException("An onnxModel feature can have at most 2 arguments.");
+
+ // Validation that the file actually exists is handled when the file is added to file distribution.
+ // Validation of inputs, outputs and corresponding types are currently handled by RankSetupValidator.
+
+ String modelConfigName;
+ OnnxModel onnxModel;
+ if (arguments.expressions().get(0) instanceof ReferenceNode) {
+ modelConfigName = arguments.expressions().get(0).toString();
+ onnxModel = search.onnxModels().get(modelConfigName);
+ if (onnxModel == null) {
+ throw new IllegalArgumentException("onnxModel argument '" + modelConfigName + "' config not found");
+ }
+ } else if (arguments.expressions().get(0) instanceof ConstantNode) {
String path = asString(arguments.expressions().get(0));
- ModelName modelName = new ModelName(null, Path.fromString(path), true);
- ConvertedModel convertedModel = ConvertedModel.fromStore(modelName, path, rankProfile);
- FeatureArguments featureArguments = new FeatureArguments(arguments);
- return convertedModel.expression(featureArguments, null);
- }
-
- String defaultOutput = onnxModel.getOutputMap().get(onnxModel.getDefaultOutput());
- String output = getModelOutput(feature.reference(), defaultOutput);
- if (! onnxModel.getOutputMap().containsValue(output)) {
- throw new IllegalArgumentException(featureName + " argument '" + output +
- "' output not found in model '" + onnxModel.getFileName() + "'");
+ modelConfigName = asValidIdentifier(path);
+ onnxModel = search.onnxModels().get(modelConfigName);
+ if (onnxModel == null) {
+ onnxModel = new OnnxModel(modelConfigName, path);
+ search.onnxModels().add(onnxModel);
+ }
+ } else {
+ throw new IllegalArgumentException("Illegal argument to onnxModel: '" + arguments.expressions().get(0) + "'");
}
- return new ReferenceNode("onnxModel", List.of(new ReferenceNode(modelConfigName)), output);
- }
- public static String getModelConfigName(Reference reference) {
- if (reference.arguments().size() > 0) {
- ExpressionNode expr = reference.arguments().expressions().get(0);
- if (expr instanceof ReferenceNode) { // refers to onnx-model config
- return expr.toString();
+ String output = null;
+ if (feature.getOutput() != null) {
+ output = feature.getOutput();
+ if ( ! hasOutputMapping(onnxModel, output)) {
+ onnxModel.addOutputNameMapping(output, output);
}
- if (expr instanceof ConstantNode) { // refers to an file path
- return asValidIdentifier(expr);
+ } else if (arguments.expressions().size() > 1) {
+ String name = asString(arguments.expressions().get(1));
+ output = asValidIdentifier(name);
+ if ( ! hasOutputMapping(onnxModel, output)) {
+ onnxModel.addOutputNameMapping(name, output);
}
}
- return null;
- }
- public static String getModelOutput(Reference reference, String defaultOutput) {
- if (reference.output() != null) {
- return reference.output();
- } else if (reference.arguments().expressions().size() == 2) {
- return asValidIdentifier(reference.arguments().expressions().get(1));
- } else if (reference.arguments().expressions().size() > 2) {
- return asValidIdentifier(reference.arguments().expressions().get(2));
- }
- return defaultOutput;
+ // Replace feature with name of config
+ ExpressionNode argument = new ReferenceNode(modelConfigName);
+ return new ReferenceNode("onnxModel", List.of(argument), output);
+
}
- public static String stripQuotes(String s) {
- if (isNotQuoteSign(s.codePointAt(0))) return s;
- if (isNotQuoteSign(s.codePointAt(s.length() - 1)))
- throw new IllegalArgumentException("argument [" + s + "] is missing end quote");
- return s.substring(1, s.length()-1);
+ private static boolean hasOutputMapping(OnnxModel onnxModel, String as) {
+ return onnxModel.getOutputMap().stream().anyMatch(m -> m.getVespaName().equals(as));
}
- public static String asValidIdentifier(String str) {
- return str.replaceAll("[^\\w\\d\\$@_]", "_");
+ private static String asString(ExpressionNode node) {
+ if ( ! (node instanceof ConstantNode))
+ throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node);
+ return stripQuotes(((ConstantNode)node).sourceString());
}
- private static String asValidIdentifier(ExpressionNode node) {
- return asValidIdentifier(asString(node));
+ private static String stripQuotes(String s) {
+ if ( ! isQuoteSign(s.codePointAt(0))) return s;
+ if ( ! isQuoteSign(s.codePointAt(s.length() - 1 )))
+ throw new IllegalArgumentException("argument [" + s + "] is missing endquote");
+ return s.substring(1, s.length()-1);
}
- private static boolean isNotQuoteSign(int c) {
- return c != '\'' && c != '"';
+ private static boolean isQuoteSign(int c) {
+ return c == '\'' || c == '"';
}
- public static String asString(ExpressionNode node) {
- if ( ! (node instanceof ConstantNode))
- throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node);
- return stripQuotes(((ConstantNode)node).sourceString());
+ private static String asValidIdentifier(String str) {
+ return str.replaceAll("[^\\w\\d\\$@_]", "_");
}
}