summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2017-12-01 09:43:36 -0800
committerJon Bratseth <bratseth@yahoo-inc.com>2017-12-01 09:43:36 -0800
commitc8c79c3363ad1149fa137f6b4899dad8369b309c (patch)
tree33ab5b4000042082a455a9c0f31fee0b0e97576e /searchlib
parentfc3a9e518eb0e5904609028e9e388d35ddc61db0 (diff)
Load using SavedModelBundle
This is to be able to access saved variables without reverse engineering the 'proprietary binary format' *eye-roll* used to save variables.
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/pom.xml5
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java28
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java4
3 files changed, 19 insertions, 18 deletions
diff --git a/searchlib/pom.xml b/searchlib/pom.xml
index 8e15e0d425c..8cb4ed8b0e3 100644
--- a/searchlib/pom.xml
+++ b/searchlib/pom.xml
@@ -44,6 +44,11 @@
<version>1.4.0</version>
</dependency>
<dependency>
+ <groupId>org.tensorflow</groupId>
+ <artifactId>tensorflow</artifactId>
+ <version>1.4.0</version>
+ </dependency>
+ <dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-core</artifactId>
<scope>test</scope>
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
index 38d9483a162..91a0f863a14 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -1,13 +1,12 @@
package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
-import com.google.protobuf.TextFormat;
-import com.yahoo.io.IOUtils;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.yolean.Exceptions;
+import org.tensorflow.SavedModelBundle;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.NodeDef;
@@ -26,13 +25,13 @@ import java.util.stream.Collectors;
/**
* Converts a saved TensorFlow model into a ranking expression and set of constants.
- *
+ *
* @author bratseth
*/
public class TensorFlowImporter {
private final OperationMapper operationMapper = new OperationMapper();
-
+
/**
* Imports a saved TensorFlow model from a directory.
* The model should be saved as a pbtxt file.
@@ -40,17 +39,14 @@ public class TensorFlowImporter {
*/
public List<RankingExpression> importModel(String modelDir, MessageLogger logger) {
try {
- SavedModel.Builder builder = SavedModel.newBuilder();
- TextFormat.getParser().merge(IOUtils.createReader(modelDir + "/saved_model.pbtxt"), builder);
- return importModel(builder.build(), logger);
-
- // TODO: Support binary reading:
- //SavedModel.parseFrom(new FileInputStream(modelDir + "/saved_model.pbtxt"));
+ SavedModelBundle model = SavedModelBundle.load(modelDir, "serve");
+ return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), logger);
+
}
catch (IOException e) {
throw new IllegalArgumentException("Could not open TensorFlow model directory '" + modelDir + "'", e);
}
-
+
}
/** Import all declared inputs in all the graphs in the given model */
@@ -84,11 +80,11 @@ public class TensorFlowImporter {
private Map<String, TensorType> importInputs(Map<String, TensorInfo> inputInfoMap) {
Map<String, TensorType> inputs = new HashMap<>();
- inputInfoMap.forEach((key, value) -> inputs.put(nameOf(value.getName()),
+ inputInfoMap.forEach((key, value) -> inputs.put(nameOf(value.getName()),
importTensorType(value.getTensorShape())));
return inputs;
}
-
+
static TensorType importTensorType(TensorShapeProto tensorShape) {
TensorType.Builder b = new TensorType.Builder();
for (int i = 0; i < tensorShape.getDimCount(); i++) {
@@ -110,7 +106,7 @@ public class TensorFlowImporter {
private TypedTensorFunction importNode(NodeDef tfNode, Map<String, TensorType> inputs, GraphDef graph, String indent) {
return tensorFunctionOf(tfNode, inputs, graph, indent);
}
-
+
private TypedTensorFunction tensorFunctionOf(NodeDef tfNode,
Map<String, TensorType> inputs,
GraphDef graph,
@@ -126,13 +122,13 @@ public class TensorFlowImporter {
default : throw new IllegalArgumentException("Conversion of TensorFlow operation '" + tfNode.getOp() + "' is not supported");
}
}
-
+
private List<TypedTensorFunction> importArguments(NodeDef tfNode, Map<String, TensorType> inputs, GraphDef graph, String indent) {
return tfNode.getInputList().stream()
.map(argNode -> importNode(getNode(nameOf(argNode), graph), inputs, graph, indent + " "))
.collect(Collectors.toList());
}
-
+
private NodeDef getNode(String name, GraphDef graph) {
return graph.getNodeList().stream()
.filter(node -> node.getName().equals(name))
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java
index c780b3d0c7d..9b53b3824e2 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java
@@ -23,9 +23,9 @@ public class TensorFlowImporterTestCase {
// Check logged messages
assertEquals(2, logger.messages.size());
- assertEquals("Skipping output 'index_to_string_Lookup:0' of signature 'tensorflow/serving/classify': Conversion of TensorFlow operation 'LookupTableFindV2' is not supported",
- logger.messages.get(0));
assertEquals("Skipping output 'TopKV2:0' of signature 'tensorflow/serving/classify': Conversion of TensorFlow operation 'TopKV2' is not supported",
+ logger.messages.get(0));
+ assertEquals("Skipping output 'index_to_string_Lookup:0' of signature 'tensorflow/serving/classify': Conversion of TensorFlow operation 'LookupTableFindV2' is not supported",
logger.messages.get(1));
// Check resulting Vespa expression