diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java | 141 |
1 files changed, 59 insertions, 82 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java index 69cdae10e47..e1ad003e5bd 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java @@ -1,36 +1,20 @@ // Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.expressiontransforms; -import com.yahoo.path.Path; import com.yahoo.searchdefinition.ImmutableSearch; import com.yahoo.searchdefinition.OnnxModel; -import com.yahoo.searchdefinition.RankProfile; -import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; -import com.yahoo.vespa.model.ml.ConvertedModel; -import com.yahoo.vespa.model.ml.FeatureArguments; -import com.yahoo.vespa.model.ml.ModelName; import java.util.List; /** - * Transforms ONNX model features of the forms: - * - * onnxModel(config_name) - * onnxModel(config_name).output - * onnxModel("path/to/model") - * onnxModel("path/to/model").output - * onnxModel("path/to/model", "path/to/output") - * onnxModel("path/to/model", "unused", "path/to/output") // signature is unused - * - * To the format expected by the backend: - * - * onnxModel(config_name).output + * Transforms instances of the onnxModel ranking feature and generates + * ONNX configuration if necessary. * * @author lesters */ @@ -49,92 +33,85 @@ public class OnnxModelTransformer extends ExpressionTransformer<RankProfileTrans private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) { if (context.rankProfile() == null) return feature; if (context.rankProfile().getSearch() == null) return feature; - return transformFeature(feature, context.rankProfile()); + return transformFeature(feature, context.rankProfile().getSearch()); } - public static ExpressionNode transformFeature(ReferenceNode feature, RankProfile rankProfile) { - ImmutableSearch search = rankProfile.getSearch(); - final String featureName = feature.getName(); - if ( ! featureName.equals("onnxModel")) return feature; + public static ReferenceNode transformFeature(ReferenceNode feature, ImmutableSearch search) { + if (!feature.getName().equals("onnxModel")) return feature; Arguments arguments = feature.getArguments(); if (arguments.isEmpty()) - throw new IllegalArgumentException("An " + featureName + " feature must take an argument referring to a " + - "onnx-model config or an ONNX file."); - if (arguments.expressions().size() > 3) - throw new IllegalArgumentException("An " + featureName + " feature can have at most 3 arguments."); - - // Check that the model configuration "onnx-model" exists. If not defined, it should have been added - // by the "OnnxModelConfigGenerator" processor. If it still doesn't exist, it is because we can't find - // the actual ONNX file, which can happen if we are restarting or upgrading an application using an - // ONNX file that was transformed to Vespa ranking expressions. We then assume it is in the model store. - - String modelConfigName = getModelConfigName(feature.reference()); - OnnxModel onnxModel = search.onnxModels().get(modelConfigName); - if (onnxModel == null) { + throw new IllegalArgumentException("An onnxModel feature must take an argument referring to a " + + "onnx-model config or a ONNX file."); + if (arguments.expressions().size() > 2) + throw new IllegalArgumentException("An onnxModel feature can have at most 2 arguments."); + + // Validation that the file actually exists is handled when the file is added to file distribution. + // Validation of inputs, outputs and corresponding types are currently handled by RankSetupValidator. + + String modelConfigName; + OnnxModel onnxModel; + if (arguments.expressions().get(0) instanceof ReferenceNode) { + modelConfigName = arguments.expressions().get(0).toString(); + onnxModel = search.onnxModels().get(modelConfigName); + if (onnxModel == null) { + throw new IllegalArgumentException("onnxModel argument '" + modelConfigName + "' config not found"); + } + } else if (arguments.expressions().get(0) instanceof ConstantNode) { String path = asString(arguments.expressions().get(0)); - ModelName modelName = new ModelName(null, Path.fromString(path), true); - ConvertedModel convertedModel = ConvertedModel.fromStore(modelName, path, rankProfile); - FeatureArguments featureArguments = new FeatureArguments(arguments); - return convertedModel.expression(featureArguments, null); - } - - String defaultOutput = onnxModel.getOutputMap().get(onnxModel.getDefaultOutput()); - String output = getModelOutput(feature.reference(), defaultOutput); - if (! onnxModel.getOutputMap().containsValue(output)) { - throw new IllegalArgumentException(featureName + " argument '" + output + - "' output not found in model '" + onnxModel.getFileName() + "'"); + modelConfigName = asValidIdentifier(path); + onnxModel = search.onnxModels().get(modelConfigName); + if (onnxModel == null) { + onnxModel = new OnnxModel(modelConfigName, path); + search.onnxModels().add(onnxModel); + } + } else { + throw new IllegalArgumentException("Illegal argument to onnxModel: '" + arguments.expressions().get(0) + "'"); } - return new ReferenceNode("onnxModel", List.of(new ReferenceNode(modelConfigName)), output); - } - public static String getModelConfigName(Reference reference) { - if (reference.arguments().size() > 0) { - ExpressionNode expr = reference.arguments().expressions().get(0); - if (expr instanceof ReferenceNode) { // refers to onnx-model config - return expr.toString(); + String output = null; + if (feature.getOutput() != null) { + output = feature.getOutput(); + if ( ! hasOutputMapping(onnxModel, output)) { + onnxModel.addOutputNameMapping(output, output); } - if (expr instanceof ConstantNode) { // refers to an file path - return asValidIdentifier(expr); + } else if (arguments.expressions().size() > 1) { + String name = asString(arguments.expressions().get(1)); + output = asValidIdentifier(name); + if ( ! hasOutputMapping(onnxModel, output)) { + onnxModel.addOutputNameMapping(name, output); } } - return null; - } - public static String getModelOutput(Reference reference, String defaultOutput) { - if (reference.output() != null) { - return reference.output(); - } else if (reference.arguments().expressions().size() == 2) { - return asValidIdentifier(reference.arguments().expressions().get(1)); - } else if (reference.arguments().expressions().size() > 2) { - return asValidIdentifier(reference.arguments().expressions().get(2)); - } - return defaultOutput; + // Replace feature with name of config + ExpressionNode argument = new ReferenceNode(modelConfigName); + return new ReferenceNode("onnxModel", List.of(argument), output); + } - public static String stripQuotes(String s) { - if (isNotQuoteSign(s.codePointAt(0))) return s; - if (isNotQuoteSign(s.codePointAt(s.length() - 1))) - throw new IllegalArgumentException("argument [" + s + "] is missing end quote"); - return s.substring(1, s.length()-1); + private static boolean hasOutputMapping(OnnxModel onnxModel, String as) { + return onnxModel.getOutputMap().stream().anyMatch(m -> m.getVespaName().equals(as)); } - public static String asValidIdentifier(String str) { - return str.replaceAll("[^\\w\\d\\$@_]", "_"); + private static String asString(ExpressionNode node) { + if ( ! (node instanceof ConstantNode)) + throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node); + return stripQuotes(((ConstantNode)node).sourceString()); } - private static String asValidIdentifier(ExpressionNode node) { - return asValidIdentifier(asString(node)); + private static String stripQuotes(String s) { + if ( ! isQuoteSign(s.codePointAt(0))) return s; + if ( ! isQuoteSign(s.codePointAt(s.length() - 1 ))) + throw new IllegalArgumentException("argument [" + s + "] is missing endquote"); + return s.substring(1, s.length()-1); } - private static boolean isNotQuoteSign(int c) { - return c != '\'' && c != '"'; + private static boolean isQuoteSign(int c) { + return c == '\'' || c == '"'; } - public static String asString(ExpressionNode node) { - if ( ! (node instanceof ConstantNode)) - throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node); - return stripQuotes(((ConstantNode)node).sourceString()); + private static String asValidIdentifier(String str) { + return str.replaceAll("[^\\w\\d\\$@_]", "_"); } } |