diff options
author | Jon Bratseth <bratseth@gmail.com> | 2022-05-19 12:03:06 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2022-05-19 12:03:06 +0200 |
commit | 5c24dc5c9642a8d9ed70aee4c950fd0678a1ebec (patch) | |
tree | bd9b74bf00c832456f0b83c1b2cd7010be387d68 /config-model/src/main/java/com/yahoo/schema/expressiontransforms/OnnxFeatureConverter.java | |
parent | f17c4fe7de4c55f5c4ee61897eab8c2f588d8405 (diff) |
Rename the 'searchdefinition' package to 'schema'
Diffstat (limited to 'config-model/src/main/java/com/yahoo/schema/expressiontransforms/OnnxFeatureConverter.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/schema/expressiontransforms/OnnxFeatureConverter.java | 64 |
1 files changed, 64 insertions, 0 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/OnnxFeatureConverter.java new file mode 100644 index 00000000000..2277491cd47 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/OnnxFeatureConverter.java @@ -0,0 +1,64 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.schema.expressiontransforms; + +import com.yahoo.path.Path; +import com.yahoo.searchlib.rankingexpression.rule.Arguments; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.vespa.model.ml.ConvertedModel; +import com.yahoo.vespa.model.ml.FeatureArguments; + +import java.io.UncheckedIOException; +import java.util.HashMap; +import java.util.Map; + +/** + * Replaces instances of the onnx(model-path, output) + * pseudofeature with the native Vespa ranking expression implementing + * the same computation. + * + * @author bratseth + * @author lesters + */ +public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> { + + /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */ + private final Map<Path, ConvertedModel> convertedOnnxModels = new HashMap<>(); + + @Override + public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) { + if (node instanceof ReferenceNode) + return transformFeature((ReferenceNode) node, context); + else if (node instanceof CompositeNode) + return super.transformChildren((CompositeNode) node, context); + else + return node; + } + + private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) { + if ( ! feature.getName().equals("onnx_vespa")) return feature; + try { + FeatureArguments arguments = asFeatureArguments(feature.getArguments()); + ConvertedModel convertedModel = + convertedOnnxModels.computeIfAbsent(arguments.path(), + path -> ConvertedModel.fromSourceOrStore(path, true, context)); + return convertedModel.expression(arguments, context); + } + catch (IllegalArgumentException | UncheckedIOException e) { + throw new IllegalArgumentException("Could not use Onnx model from " + feature, e); + } + } + + private FeatureArguments asFeatureArguments(Arguments arguments) { + if (arguments.isEmpty()) + throw new IllegalArgumentException("An ONNX node must take an argument pointing to " + + "the ONNX model file under [application]/models"); + if (arguments.expressions().size() > 3) + throw new IllegalArgumentException("An onnx feature can have at most 3 arguments"); + + return new FeatureArguments(arguments); + } + +} |