summaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java65
1 files changed, 52 insertions, 13 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
index d8ffbd7d030..e1ad003e5bd 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
@@ -1,6 +1,7 @@
// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition.expressiontransforms;
+import com.yahoo.searchdefinition.ImmutableSearch;
import com.yahoo.searchdefinition.OnnxModel;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
@@ -12,9 +13,8 @@ import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
import java.util.List;
/**
- * Transforms instances of the onnxModel(model-path, output) ranking feature
- * by adding the model file to file distribution and rewriting this feature
- * to point to the generated configuration.
+ * Transforms instances of the onnxModel ranking feature and generates
+ * ONNX configuration if necessary.
*
* @author lesters
*/
@@ -31,27 +31,66 @@ public class OnnxModelTransformer extends ExpressionTransformer<RankProfileTrans
}
private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) {
+ if (context.rankProfile() == null) return feature;
+ if (context.rankProfile().getSearch() == null) return feature;
+ return transformFeature(feature, context.rankProfile().getSearch());
+ }
+
+ public static ReferenceNode transformFeature(ReferenceNode feature, ImmutableSearch search) {
if (!feature.getName().equals("onnxModel")) return feature;
Arguments arguments = feature.getArguments();
if (arguments.isEmpty())
- throw new IllegalArgumentException("An onnxModel feature must take an argument pointing to the ONNX file.");
+ throw new IllegalArgumentException("An onnxModel feature must take an argument referring to a " +
+ "onnx-model config or a ONNX file.");
if (arguments.expressions().size() > 2)
throw new IllegalArgumentException("An onnxModel feature can have at most 2 arguments.");
- String path = asString(arguments.expressions().get(0));
- String name = toModelName(path);
- String output = arguments.expressions().size() > 1 ? asString(arguments.expressions().get(1)) : null;
-
// Validation that the file actually exists is handled when the file is added to file distribution.
// Validation of inputs, outputs and corresponding types are currently handled by RankSetupValidator.
- // Add model to config
- context.rankProfile().getSearch().onnxModels().add(new OnnxModel(name, path));
+ String modelConfigName;
+ OnnxModel onnxModel;
+ if (arguments.expressions().get(0) instanceof ReferenceNode) {
+ modelConfigName = arguments.expressions().get(0).toString();
+ onnxModel = search.onnxModels().get(modelConfigName);
+ if (onnxModel == null) {
+ throw new IllegalArgumentException("onnxModel argument '" + modelConfigName + "' config not found");
+ }
+ } else if (arguments.expressions().get(0) instanceof ConstantNode) {
+ String path = asString(arguments.expressions().get(0));
+ modelConfigName = asValidIdentifier(path);
+ onnxModel = search.onnxModels().get(modelConfigName);
+ if (onnxModel == null) {
+ onnxModel = new OnnxModel(modelConfigName, path);
+ search.onnxModels().add(onnxModel);
+ }
+ } else {
+ throw new IllegalArgumentException("Illegal argument to onnxModel: '" + arguments.expressions().get(0) + "'");
+ }
+
+ String output = null;
+ if (feature.getOutput() != null) {
+ output = feature.getOutput();
+ if ( ! hasOutputMapping(onnxModel, output)) {
+ onnxModel.addOutputNameMapping(output, output);
+ }
+ } else if (arguments.expressions().size() > 1) {
+ String name = asString(arguments.expressions().get(1));
+ output = asValidIdentifier(name);
+ if ( ! hasOutputMapping(onnxModel, output)) {
+ onnxModel.addOutputNameMapping(name, output);
+ }
+ }
// Replace feature with name of config
- ExpressionNode argument = new ReferenceNode(name);
+ ExpressionNode argument = new ReferenceNode(modelConfigName);
return new ReferenceNode("onnxModel", List.of(argument), output);
+
+ }
+
+ private static boolean hasOutputMapping(OnnxModel onnxModel, String as) {
+ return onnxModel.getOutputMap().stream().anyMatch(m -> m.getVespaName().equals(as));
}
private static String asString(ExpressionNode node) {
@@ -71,8 +110,8 @@ public class OnnxModelTransformer extends ExpressionTransformer<RankProfileTrans
return c == '\'' || c == '"';
}
- public static String toModelName(String path) {
- return path.replaceAll("[^\\w\\d\\$@_]", "_");
+ private static String asValidIdentifier(String str) {
+ return str.replaceAll("[^\\w\\d\\$@_]", "_");
}
}