diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java | 65 |
1 files changed, 52 insertions, 13 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java index d8ffbd7d030..e1ad003e5bd 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java @@ -1,6 +1,7 @@ // Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.expressiontransforms; +import com.yahoo.searchdefinition.ImmutableSearch; import com.yahoo.searchdefinition.OnnxModel; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; @@ -12,9 +13,8 @@ import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; import java.util.List; /** - * Transforms instances of the onnxModel(model-path, output) ranking feature - * by adding the model file to file distribution and rewriting this feature - * to point to the generated configuration. + * Transforms instances of the onnxModel ranking feature and generates + * ONNX configuration if necessary. * * @author lesters */ @@ -31,27 +31,66 @@ public class OnnxModelTransformer extends ExpressionTransformer<RankProfileTrans } private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) { + if (context.rankProfile() == null) return feature; + if (context.rankProfile().getSearch() == null) return feature; + return transformFeature(feature, context.rankProfile().getSearch()); + } + + public static ReferenceNode transformFeature(ReferenceNode feature, ImmutableSearch search) { if (!feature.getName().equals("onnxModel")) return feature; Arguments arguments = feature.getArguments(); if (arguments.isEmpty()) - throw new IllegalArgumentException("An onnxModel feature must take an argument pointing to the ONNX file."); + throw new IllegalArgumentException("An onnxModel feature must take an argument referring to a " + + "onnx-model config or a ONNX file."); if (arguments.expressions().size() > 2) throw new IllegalArgumentException("An onnxModel feature can have at most 2 arguments."); - String path = asString(arguments.expressions().get(0)); - String name = toModelName(path); - String output = arguments.expressions().size() > 1 ? asString(arguments.expressions().get(1)) : null; - // Validation that the file actually exists is handled when the file is added to file distribution. // Validation of inputs, outputs and corresponding types are currently handled by RankSetupValidator. - // Add model to config - context.rankProfile().getSearch().onnxModels().add(new OnnxModel(name, path)); + String modelConfigName; + OnnxModel onnxModel; + if (arguments.expressions().get(0) instanceof ReferenceNode) { + modelConfigName = arguments.expressions().get(0).toString(); + onnxModel = search.onnxModels().get(modelConfigName); + if (onnxModel == null) { + throw new IllegalArgumentException("onnxModel argument '" + modelConfigName + "' config not found"); + } + } else if (arguments.expressions().get(0) instanceof ConstantNode) { + String path = asString(arguments.expressions().get(0)); + modelConfigName = asValidIdentifier(path); + onnxModel = search.onnxModels().get(modelConfigName); + if (onnxModel == null) { + onnxModel = new OnnxModel(modelConfigName, path); + search.onnxModels().add(onnxModel); + } + } else { + throw new IllegalArgumentException("Illegal argument to onnxModel: '" + arguments.expressions().get(0) + "'"); + } + + String output = null; + if (feature.getOutput() != null) { + output = feature.getOutput(); + if ( ! hasOutputMapping(onnxModel, output)) { + onnxModel.addOutputNameMapping(output, output); + } + } else if (arguments.expressions().size() > 1) { + String name = asString(arguments.expressions().get(1)); + output = asValidIdentifier(name); + if ( ! hasOutputMapping(onnxModel, output)) { + onnxModel.addOutputNameMapping(name, output); + } + } // Replace feature with name of config - ExpressionNode argument = new ReferenceNode(name); + ExpressionNode argument = new ReferenceNode(modelConfigName); return new ReferenceNode("onnxModel", List.of(argument), output); + + } + + private static boolean hasOutputMapping(OnnxModel onnxModel, String as) { + return onnxModel.getOutputMap().stream().anyMatch(m -> m.getVespaName().equals(as)); } private static String asString(ExpressionNode node) { @@ -71,8 +110,8 @@ public class OnnxModelTransformer extends ExpressionTransformer<RankProfileTrans return c == '\'' || c == '"'; } - public static String toModelName(String path) { - return path.replaceAll("[^\\w\\d\\$@_]", "_"); + private static String asValidIdentifier(String str) { + return str.replaceAll("[^\\w\\d\\$@_]", "_"); } } |