diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-11-07 16:01:10 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-11-07 16:01:10 +0100 |
commit | d68ea53b1b88f4f0720f10dc94e694f2ed4bb542 (patch) | |
tree | 8aee94e382332c28f4d92a64d67a8449f272420b | |
parent | dfcd99818b4b0151ad226df03548915f9dddd9fb (diff) |
TF model translation WIP
5 files changed, 254 insertions, 1 deletions
@@ -441,7 +441,7 @@ <dependency> <groupId>com.google.protobuf</groupId> <artifactId>protobuf-java</artifactId> - <version>2.4.1</version> + <version>3.4.0</version> </dependency> <dependency> <groupId>com.googlecode.jmockit</groupId> diff --git a/searchlib/pom.xml b/searchlib/pom.xml index 36e6fa1ffda..bb305f460ca 100644 --- a/searchlib/pom.xml +++ b/searchlib/pom.xml @@ -34,6 +34,15 @@ <artifactId>vespajlib</artifactId> <version>${project.version}</version> </dependency> + <dependency> + <groupId>com.google.protobuf</groupId> + <artifactId>protobuf-java</artifactId> + </dependency> + <dependency> + <groupId>org.tensorflow</groupId> + <artifactId>proto</artifactId> + <version>1.4.0</version> + </dependency> </dependencies> <build> <plugins> diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java new file mode 100644 index 00000000000..160af794faf --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -0,0 +1,134 @@ +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; + +import com.google.protobuf.ProtocolStringList; +import com.google.protobuf.TextFormat; +import com.yahoo.io.IOUtils; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.functions.Join; +import com.yahoo.tensor.functions.Matmul; +import com.yahoo.tensor.functions.ScalarFunctions; +import com.yahoo.tensor.functions.TensorFunction; +import com.yahoo.yolean.Exceptions; +import org.tensorflow.framework.GraphDef; +import org.tensorflow.framework.MetaGraphDef; +import org.tensorflow.framework.NodeDef; +import org.tensorflow.framework.OpDef; +import org.tensorflow.framework.SavedModel; +import org.tensorflow.framework.SignatureDef; +import org.tensorflow.framework.TensorInfo; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * Converts a saved TensorFlow model into a ranking expression and set of constants. + * + * @author bratseth + */ +public class TensorFlowImporter { + + /** + * Imports a saved TensorFlow model from a directory. + * The model should be saved as a pbtxt file. + * The name of the model is taken at the pbtxt file name (not including the .pbtxt ending). + */ + public void importModel(String modelDir) { + try { + SavedModel.Builder builder = SavedModel.newBuilder(); + TextFormat.getParser().merge(IOUtils.createReader(modelDir + "/saved_model.pbtxt"), builder); + //System.out.println("Read " + builder); + importModel(builder.build()); + + // TODO: Support binary reading: + //SavedModel.parseFrom(new FileInputStream(modelDir + "/saved_model.pbtxt")); + } + catch (IOException e) { + throw new IllegalArgumentException("Could not open TensorFlow model directory '" + modelDir + "'", e); + } + + } + + private void importModel(SavedModel model) { + model.getMetaGraphsList().forEach(this::importGraph); + } + + private void importGraph(MetaGraphDef graph) { + System.out.println("Importing graph"); + for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) { + System.out.println(" Importing signature def " + signatureEntry.getKey() + + " with method name " + signatureEntry.getValue().getMethodName()); + signatureEntry.getValue().getOutputsMap().values() + .forEach(output -> importOutput(output, signatureEntry.getValue().getMethodName(), graph.getGraphDef())); + } + } + + private void importOutput(TensorInfo output, String signatureName, GraphDef graph) { + try { + System.out.println(" Importing output " + output.getName()); + NodeDef node = getNode(nameOf(output.getName()), graph); + // System.out.println("Ops:-------------"); + // graph.getStrippedOpList().getOpList().stream().forEach(s -> System.out.println(s.getName())); + // System.out.println("-----------------"); + importNode(node, graph, ""); + } + catch (IllegalArgumentException e) { + System.err.println("Skipping output '" + output.getName() + "' of signature '" + signatureName + "': " + Exceptions.toMessageString(e)); + } + } + + private ExpressionNode importNode(NodeDef tfNode, GraphDef graph, String indent) { + System.out.println(" " + indent + "Importing node " + tfNode.getName()); + List<ExpressionNode> arguments = new ArrayList<>(); + for (String input : tfNode.getInputList()) + arguments.add(importNode(getNode(nameOf(input), graph), graph, indent + " ")); + ExpressionNode node = expressionNodeOf(tfNode.getName(), arguments); + } + + private ExpressionNode expressionNodeOf(String node, List<ExpressionNode> arguments) { + return new TensorFunctionNode(tensorFunctionOf(node, arguments.stream() + .map(TensorFunctionNode.TensorFunctionExpressionNode::new) + .collect(Collectors.toList()))); + } + + private TensorFunction tensorFunctionOf(String node, List<TensorFunction> arguments) { + switch (node) { + case "add" : return new Join(arguments.get(0), arguments.get(1), ScalarFunctions.add()); + case "MatMul" : return new Matmul(arguments.get(0), arguments.get(1), ScalarFunctions.add()); + } + } + + private NodeDef getNode(String name, GraphDef graph) { + return graph.getNodeList().stream() + .filter(node -> node.getName().equals(name)) + .findFirst() + .orElseThrow(() -> new IllegalArgumentException("Could not find node '" + name + "'")); + } + + private void importOp(OpDef op, MetaGraphDef.MetaInfoDef graph) { + System.out.println(" Importing op " + op.getName()); + } + + private OpDef getOp(String name, MetaGraphDef.MetaInfoDef graph) { + return graph.getStrippedOpList().getOpList().stream() + .filter(op -> op.getName().equals(name)) + .findFirst() + .orElseThrow(() -> new IllegalArgumentException("Could not find operation '" + name + "'")); + } + + /** + * An output has the form name:index. + * This returns the name part without the index. + */ + private String nameOf(String outputName) { + return outputName.split(":")[0]; + } + + private boolean contains(String string, ProtocolStringList strings) { + return strings.asByteStringList().stream().anyMatch(s -> s.toStringUtf8().equals(string)); + } + +} diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java new file mode 100644 index 00000000000..4c511047118 --- /dev/null +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java @@ -0,0 +1,15 @@ +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; + +import org.junit.Test; + +/** + * @author bratseth + */ +public class TensorFlowImporterTestCase { + + @Test + public void testModel1() { + new TensorFlowImporter().importModel("src/test/files/integration/tensorflow/model1/"); + } + +} diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java new file mode 100644 index 00000000000..6606e278102 --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java @@ -0,0 +1,95 @@ +package com.yahoo.tensor.functions; + +import com.google.common.collect.ImmutableList; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * @author bratseth + */ +public class MatmulTestCase { + + @Test + public void testMatmul2d() { + // Convention: a is the 'outermost' dimension, etc. + Tensor.Builder ab = Tensor.Builder.of(TensorType.fromSpec("tensor(a[2],b[3])")); + ab.cell( 1,0, 0); + ab.cell( 2,0, 1); + ab.cell( 3,0, 2); + ab.cell( 4,1, 0); + ab.cell( 5,1, 1); + ab.cell( 6,1, 2); + Tensor a = ab.build(); + + Tensor.Builder bb = Tensor.Builder.of(TensorType.fromSpec("tensor(a[3],b[2])")); + bb.cell( 7,0, 0); + bb.cell( 8,0, 1); + bb.cell( 9,1, 0); + bb.cell(10,1, 1); + bb.cell(11,2, 0); + bb.cell(12,2, 1); + Tensor b = bb.build(); + + Tensor.Builder rb = Tensor.Builder.of(TensorType.fromSpec("tensor(a[2],c[2])")); + rb.cell( 58,0, 0); + rb.cell( 64,0, 1); + rb.cell(139,1, 0); + rb.cell(154,1, 1); + Tensor r = rb.build(); + + Tensor result = a.matmul(b.rename(ImmutableList.of("a","b"),ImmutableList.of("b","c")), "b"); + assertEquals(r, result); + } + + @Test + public void testMatmul3d() { + // Convention: a is the 'outermost' dimension, etc. + Tensor.Builder ab = Tensor.Builder.of(TensorType.fromSpec("tensor(a[2],b[2],c[3])")); + ab.cell( 1,0, 0, 0); + ab.cell( 2,0, 0, 1); + ab.cell( 3,0, 0, 2); + ab.cell( 4,0, 1, 0); + ab.cell( 5,0, 1, 1); + ab.cell( 6,0, 1, 2); + ab.cell( 7,1, 0, 0); + ab.cell( 8,1, 0, 1); + ab.cell( 9,1, 0, 2); + ab.cell(10,1, 1, 0); + ab.cell(11,1, 1, 1); + ab.cell(12,1, 1, 2); + Tensor a = ab.build(); + + Tensor.Builder bb = Tensor.Builder.of(TensorType.fromSpec("tensor(a[2],b[3],c[2])")); + bb.cell(13,0, 0, 0); + bb.cell(14,0, 0, 1); + bb.cell(15,0, 1, 0); + bb.cell(16,0, 1, 1); + bb.cell(17,0, 2, 0); + bb.cell(18,0, 2, 1); + bb.cell(19,1, 0, 0); + bb.cell(20,1, 0, 1); + bb.cell(21,1, 1, 0); + bb.cell(22,1, 1, 1); + bb.cell(23,1, 2, 0); + bb.cell(24,1, 2, 1); + Tensor b = bb.build(); + + Tensor.Builder rb = Tensor.Builder.of(TensorType.fromSpec("tensor(a[2],b[2],d[2])")); + rb.cell( 94,0, 0, 0); + rb.cell(100,0, 0, 1); + rb.cell(229,0, 1, 0); + rb.cell(244,0, 1, 1); + rb.cell(508,1, 0, 0); + rb.cell(532,1, 0, 1); + rb.cell(697,1, 1, 0); + rb.cell(730,1, 1, 1); + Tensor r = rb.build(); + + Tensor result = a.matmul(b.rename(ImmutableList.of("b","c"),ImmutableList.of("c","d")), "c"); + System.out.println(result); + } + +} |