summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2017-12-20 16:07:16 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2017-12-20 16:07:16 +0100
commitd97e6b0a72a7c7d38365475637bff3897bc7597a (patch)
tree900703a5c105bfc805e11971757572a97a0bae3e /searchlib
parent82ff058a65bb3ae4b195906a8f927d00846b7fcc (diff)
Simplify
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java19
1 files changed, 5 insertions, 14 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
index cf72a327ba0..4a6551adca7 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -1,11 +1,9 @@
package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.ScalarFunctions;
-import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.yolean.Exceptions;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.framework.GraphDef;
@@ -64,7 +62,8 @@ public class TensorFlowImporter {
for (Map.Entry<String, TensorInfo> output : signatureEntry.getValue().getOutputsMap().entrySet()) {
String outputName = output.getKey();
try {
- ExpressionNode node = importOutput(output.getValue(), graph.getGraphDef(), model, result);
+ NodeDef node = getNode(nameOf(output.getValue().getName()), graph.getGraphDef());
+ importNode(node, graph.getGraphDef(), model, result);
signature.output(outputName, nameOf(output.getValue().getName()));
}
catch (IllegalArgumentException e) {
@@ -98,21 +97,13 @@ public class TensorFlowImporter {
return b.build();
}
- private ExpressionNode importOutput(TensorInfo output, GraphDef graph, SavedModelBundle model, ImportResult result) {
- return importNode(nameOf(output.getName()), graph, model, result);
- }
-
- private ExpressionNode importNode(String nodeName, GraphDef graph, SavedModelBundle model, ImportResult result) {
- TensorFunction function = importNode(getNode(nodeName, graph), graph, model, result).function();
- result.expression(nodeName, new RankingExpression(nodeName, new TensorFunctionNode(function)));
- return new TensorFunctionNode(function); // wrap top level (only) as an expression // TODO: waht to return
- }
-
/** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */
private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {
TypedTensorFunction function = tensorFunctionOf(tfNode, graph, model, result);
+ // We add all intermediate nodes imported as separate expressions. Only those referenced in a signature output
+ // will be used
result.expression(tfNode.getName(), new RankingExpression(tfNode.getName(), new TensorFunctionNode(function.function())));
- return function; // TODO: waht to return
+ return function;
}
private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {