diff options
Diffstat (limited to 'searchlib')
-rw-r--r-- | searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java | 19 |
1 files changed, 5 insertions, 14 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index cf72a327ba0..4a6551adca7 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -1,11 +1,9 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow; import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.ScalarFunctions; -import com.yahoo.tensor.functions.TensorFunction; import com.yahoo.yolean.Exceptions; import org.tensorflow.SavedModelBundle; import org.tensorflow.framework.GraphDef; @@ -64,7 +62,8 @@ public class TensorFlowImporter { for (Map.Entry<String, TensorInfo> output : signatureEntry.getValue().getOutputsMap().entrySet()) { String outputName = output.getKey(); try { - ExpressionNode node = importOutput(output.getValue(), graph.getGraphDef(), model, result); + NodeDef node = getNode(nameOf(output.getValue().getName()), graph.getGraphDef()); + importNode(node, graph.getGraphDef(), model, result); signature.output(outputName, nameOf(output.getValue().getName())); } catch (IllegalArgumentException e) { @@ -98,21 +97,13 @@ public class TensorFlowImporter { return b.build(); } - private ExpressionNode importOutput(TensorInfo output, GraphDef graph, SavedModelBundle model, ImportResult result) { - return importNode(nameOf(output.getName()), graph, model, result); - } - - private ExpressionNode importNode(String nodeName, GraphDef graph, SavedModelBundle model, ImportResult result) { - TensorFunction function = importNode(getNode(nodeName, graph), graph, model, result).function(); - result.expression(nodeName, new RankingExpression(nodeName, new TensorFunctionNode(function))); - return new TensorFunctionNode(function); // wrap top level (only) as an expression // TODO: waht to return - } - /** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */ private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) { TypedTensorFunction function = tensorFunctionOf(tfNode, graph, model, result); + // We add all intermediate nodes imported as separate expressions. Only those referenced in a signature output + // will be used result.expression(tfNode.getName(), new RankingExpression(tfNode.getName(), new TensorFunctionNode(function.function()))); - return function; // TODO: waht to return + return function; } private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) { |