diff options
author | Jon Bratseth <bratseth@gmail.com> | 2022-05-19 12:03:06 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2022-05-19 12:03:06 +0200 |
commit | 5c24dc5c9642a8d9ed70aee4c950fd0678a1ebec (patch) | |
tree | bd9b74bf00c832456f0b83c1b2cd7010be387d68 /config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java | |
parent | f17c4fe7de4c55f5c4ee61897eab8c2f588d8405 (diff) |
Rename the 'searchdefinition' package to 'schema'
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java | 66 |
1 files changed, 0 insertions, 66 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java deleted file mode 100644 index 7b165d94cae..00000000000 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchdefinition.expressiontransforms; - -import com.yahoo.path.Path; -import com.yahoo.searchlib.rankingexpression.rule.Arguments; -import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; -import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; -import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; -import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; -import com.yahoo.vespa.model.ml.ConvertedModel; -import com.yahoo.vespa.model.ml.FeatureArguments; - -import java.io.UncheckedIOException; -import java.util.HashMap; -import java.util.Map; - -/** - * Replaces instances of the tensorflow(model-path, signature, output) - * pseudofeature with the native Vespa ranking expression implementing - * the same computation. - * - * @author bratseth - */ -public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> { - - /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */ - private final Map<Path, ConvertedModel> convertedTensorFlowModels = new HashMap<>(); - - public TensorFlowFeatureConverter() {} - - @Override - public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) { - if (node instanceof ReferenceNode) - return transformFeature((ReferenceNode) node, context); - else if (node instanceof CompositeNode) - return super.transformChildren((CompositeNode) node, context); - else - return node; - } - - private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) { - if ( ! feature.getName().equals("tensorflow")) return feature; - - try { - FeatureArguments arguments = asFeatureArguments(feature.getArguments()); - ConvertedModel convertedModel = - convertedTensorFlowModels.computeIfAbsent(arguments.path(), - path -> ConvertedModel.fromSourceOrStore(path, false, context)); - return convertedModel.expression(arguments, context); - } - catch (IllegalArgumentException | UncheckedIOException e) { - throw new IllegalArgumentException("Could not use tensorflow model from " + feature, e); - } - } - - private FeatureArguments asFeatureArguments(Arguments arguments) { - if (arguments.isEmpty()) - throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " + - "the tensorflow model directory under [application]/models"); - if (arguments.expressions().size() > 3) - throw new IllegalArgumentException("A tensorflow feature can have at most 3 arguments"); - - return new FeatureArguments(arguments); - } - -} |