diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-12-01 09:43:36 -0800 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-12-01 09:43:36 -0800 |
commit | c8c79c3363ad1149fa137f6b4899dad8369b309c (patch) | |
tree | 33ab5b4000042082a455a9c0f31fee0b0e97576e /searchlib | |
parent | fc3a9e518eb0e5904609028e9e388d35ddc61db0 (diff) |
Load using SavedModelBundle
This is to be able to access saved variables without reverse engineering
the 'proprietary binary format' *eye-roll* used to save variables.
Diffstat (limited to 'searchlib')
3 files changed, 19 insertions, 18 deletions
diff --git a/searchlib/pom.xml b/searchlib/pom.xml index 8e15e0d425c..8cb4ed8b0e3 100644 --- a/searchlib/pom.xml +++ b/searchlib/pom.xml @@ -44,6 +44,11 @@ <version>1.4.0</version> </dependency> <dependency> + <groupId>org.tensorflow</groupId> + <artifactId>tensorflow</artifactId> + <version>1.4.0</version> + </dependency> + <dependency> <groupId>com.fasterxml.jackson.core</groupId> <artifactId>jackson-core</artifactId> <scope>test</scope> diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index 38d9483a162..91a0f863a14 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -1,13 +1,12 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow; -import com.google.protobuf.TextFormat; -import com.yahoo.io.IOUtils; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.ScalarFunctions; import com.yahoo.yolean.Exceptions; +import org.tensorflow.SavedModelBundle; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.MetaGraphDef; import org.tensorflow.framework.NodeDef; @@ -26,13 +25,13 @@ import java.util.stream.Collectors; /** * Converts a saved TensorFlow model into a ranking expression and set of constants. - * + * * @author bratseth */ public class TensorFlowImporter { private final OperationMapper operationMapper = new OperationMapper(); - + /** * Imports a saved TensorFlow model from a directory. * The model should be saved as a pbtxt file. @@ -40,17 +39,14 @@ public class TensorFlowImporter { */ public List<RankingExpression> importModel(String modelDir, MessageLogger logger) { try { - SavedModel.Builder builder = SavedModel.newBuilder(); - TextFormat.getParser().merge(IOUtils.createReader(modelDir + "/saved_model.pbtxt"), builder); - return importModel(builder.build(), logger); - - // TODO: Support binary reading: - //SavedModel.parseFrom(new FileInputStream(modelDir + "/saved_model.pbtxt")); + SavedModelBundle model = SavedModelBundle.load(modelDir, "serve"); + return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), logger); + } catch (IOException e) { throw new IllegalArgumentException("Could not open TensorFlow model directory '" + modelDir + "'", e); } - + } /** Import all declared inputs in all the graphs in the given model */ @@ -84,11 +80,11 @@ public class TensorFlowImporter { private Map<String, TensorType> importInputs(Map<String, TensorInfo> inputInfoMap) { Map<String, TensorType> inputs = new HashMap<>(); - inputInfoMap.forEach((key, value) -> inputs.put(nameOf(value.getName()), + inputInfoMap.forEach((key, value) -> inputs.put(nameOf(value.getName()), importTensorType(value.getTensorShape()))); return inputs; } - + static TensorType importTensorType(TensorShapeProto tensorShape) { TensorType.Builder b = new TensorType.Builder(); for (int i = 0; i < tensorShape.getDimCount(); i++) { @@ -110,7 +106,7 @@ public class TensorFlowImporter { private TypedTensorFunction importNode(NodeDef tfNode, Map<String, TensorType> inputs, GraphDef graph, String indent) { return tensorFunctionOf(tfNode, inputs, graph, indent); } - + private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, Map<String, TensorType> inputs, GraphDef graph, @@ -126,13 +122,13 @@ public class TensorFlowImporter { default : throw new IllegalArgumentException("Conversion of TensorFlow operation '" + tfNode.getOp() + "' is not supported"); } } - + private List<TypedTensorFunction> importArguments(NodeDef tfNode, Map<String, TensorType> inputs, GraphDef graph, String indent) { return tfNode.getInputList().stream() .map(argNode -> importNode(getNode(nameOf(argNode), graph), inputs, graph, indent + " ")) .collect(Collectors.toList()); } - + private NodeDef getNode(String name, GraphDef graph) { return graph.getNodeList().stream() .filter(node -> node.getName().equals(name)) diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java index c780b3d0c7d..9b53b3824e2 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java @@ -23,9 +23,9 @@ public class TensorFlowImporterTestCase { // Check logged messages assertEquals(2, logger.messages.size()); - assertEquals("Skipping output 'index_to_string_Lookup:0' of signature 'tensorflow/serving/classify': Conversion of TensorFlow operation 'LookupTableFindV2' is not supported", - logger.messages.get(0)); assertEquals("Skipping output 'TopKV2:0' of signature 'tensorflow/serving/classify': Conversion of TensorFlow operation 'TopKV2' is not supported", + logger.messages.get(0)); + assertEquals("Skipping output 'index_to_string_Lookup:0' of signature 'tensorflow/serving/classify': Conversion of TensorFlow operation 'LookupTableFindV2' is not supported", logger.messages.get(1)); // Check resulting Vespa expression |