summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--pom.xml2
-rw-r--r--searchlib/pom.xml9
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java134
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java15
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java95
5 files changed, 254 insertions, 1 deletions
diff --git a/pom.xml b/pom.xml
index c6e7168904a..9acd65f1a54 100644
--- a/pom.xml
+++ b/pom.xml
@@ -441,7 +441,7 @@
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
- <version>2.4.1</version>
+ <version>3.4.0</version>
</dependency>
<dependency>
<groupId>com.googlecode.jmockit</groupId>
diff --git a/searchlib/pom.xml b/searchlib/pom.xml
index 36e6fa1ffda..bb305f460ca 100644
--- a/searchlib/pom.xml
+++ b/searchlib/pom.xml
@@ -34,6 +34,15 @@
<artifactId>vespajlib</artifactId>
<version>${project.version}</version>
</dependency>
+ <dependency>
+ <groupId>com.google.protobuf</groupId>
+ <artifactId>protobuf-java</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.tensorflow</groupId>
+ <artifactId>proto</artifactId>
+ <version>1.4.0</version>
+ </dependency>
</dependencies>
<build>
<plugins>
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
new file mode 100644
index 00000000000..160af794faf
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -0,0 +1,134 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.google.protobuf.ProtocolStringList;
+import com.google.protobuf.TextFormat;
+import com.yahoo.io.IOUtils;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.functions.Join;
+import com.yahoo.tensor.functions.Matmul;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.TensorFunction;
+import com.yahoo.yolean.Exceptions;
+import org.tensorflow.framework.GraphDef;
+import org.tensorflow.framework.MetaGraphDef;
+import org.tensorflow.framework.NodeDef;
+import org.tensorflow.framework.OpDef;
+import org.tensorflow.framework.SavedModel;
+import org.tensorflow.framework.SignatureDef;
+import org.tensorflow.framework.TensorInfo;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+/**
+ * Converts a saved TensorFlow model into a ranking expression and set of constants.
+ *
+ * @author bratseth
+ */
+public class TensorFlowImporter {
+
+ /**
+ * Imports a saved TensorFlow model from a directory.
+ * The model should be saved as a pbtxt file.
+ * The name of the model is taken at the pbtxt file name (not including the .pbtxt ending).
+ */
+ public void importModel(String modelDir) {
+ try {
+ SavedModel.Builder builder = SavedModel.newBuilder();
+ TextFormat.getParser().merge(IOUtils.createReader(modelDir + "/saved_model.pbtxt"), builder);
+ //System.out.println("Read " + builder);
+ importModel(builder.build());
+
+ // TODO: Support binary reading:
+ //SavedModel.parseFrom(new FileInputStream(modelDir + "/saved_model.pbtxt"));
+ }
+ catch (IOException e) {
+ throw new IllegalArgumentException("Could not open TensorFlow model directory '" + modelDir + "'", e);
+ }
+
+ }
+
+ private void importModel(SavedModel model) {
+ model.getMetaGraphsList().forEach(this::importGraph);
+ }
+
+ private void importGraph(MetaGraphDef graph) {
+ System.out.println("Importing graph");
+ for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) {
+ System.out.println(" Importing signature def " + signatureEntry.getKey() +
+ " with method name " + signatureEntry.getValue().getMethodName());
+ signatureEntry.getValue().getOutputsMap().values()
+ .forEach(output -> importOutput(output, signatureEntry.getValue().getMethodName(), graph.getGraphDef()));
+ }
+ }
+
+ private void importOutput(TensorInfo output, String signatureName, GraphDef graph) {
+ try {
+ System.out.println(" Importing output " + output.getName());
+ NodeDef node = getNode(nameOf(output.getName()), graph);
+ // System.out.println("Ops:-------------");
+ // graph.getStrippedOpList().getOpList().stream().forEach(s -> System.out.println(s.getName()));
+ // System.out.println("-----------------");
+ importNode(node, graph, "");
+ }
+ catch (IllegalArgumentException e) {
+ System.err.println("Skipping output '" + output.getName() + "' of signature '" + signatureName + "': " + Exceptions.toMessageString(e));
+ }
+ }
+
+ private ExpressionNode importNode(NodeDef tfNode, GraphDef graph, String indent) {
+ System.out.println(" " + indent + "Importing node " + tfNode.getName());
+ List<ExpressionNode> arguments = new ArrayList<>();
+ for (String input : tfNode.getInputList())
+ arguments.add(importNode(getNode(nameOf(input), graph), graph, indent + " "));
+ ExpressionNode node = expressionNodeOf(tfNode.getName(), arguments);
+ }
+
+ private ExpressionNode expressionNodeOf(String node, List<ExpressionNode> arguments) {
+ return new TensorFunctionNode(tensorFunctionOf(node, arguments.stream()
+ .map(TensorFunctionNode.TensorFunctionExpressionNode::new)
+ .collect(Collectors.toList())));
+ }
+
+ private TensorFunction tensorFunctionOf(String node, List<TensorFunction> arguments) {
+ switch (node) {
+ case "add" : return new Join(arguments.get(0), arguments.get(1), ScalarFunctions.add());
+ case "MatMul" : return new Matmul(arguments.get(0), arguments.get(1), ScalarFunctions.add());
+ }
+ }
+
+ private NodeDef getNode(String name, GraphDef graph) {
+ return graph.getNodeList().stream()
+ .filter(node -> node.getName().equals(name))
+ .findFirst()
+ .orElseThrow(() -> new IllegalArgumentException("Could not find node '" + name + "'"));
+ }
+
+ private void importOp(OpDef op, MetaGraphDef.MetaInfoDef graph) {
+ System.out.println(" Importing op " + op.getName());
+ }
+
+ private OpDef getOp(String name, MetaGraphDef.MetaInfoDef graph) {
+ return graph.getStrippedOpList().getOpList().stream()
+ .filter(op -> op.getName().equals(name))
+ .findFirst()
+ .orElseThrow(() -> new IllegalArgumentException("Could not find operation '" + name + "'"));
+ }
+
+ /**
+ * An output has the form name:index.
+ * This returns the name part without the index.
+ */
+ private String nameOf(String outputName) {
+ return outputName.split(":")[0];
+ }
+
+ private boolean contains(String string, ProtocolStringList strings) {
+ return strings.asByteStringList().stream().anyMatch(s -> s.toStringUtf8().equals(string));
+ }
+
+}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java
new file mode 100644
index 00000000000..4c511047118
--- /dev/null
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java
@@ -0,0 +1,15 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import org.junit.Test;
+
+/**
+ * @author bratseth
+ */
+public class TensorFlowImporterTestCase {
+
+ @Test
+ public void testModel1() {
+ new TensorFlowImporter().importModel("src/test/files/integration/tensorflow/model1/");
+ }
+
+}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java
new file mode 100644
index 00000000000..6606e278102
--- /dev/null
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java
@@ -0,0 +1,95 @@
+package com.yahoo.tensor.functions;
+
+import com.google.common.collect.ImmutableList;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * @author bratseth
+ */
+public class MatmulTestCase {
+
+ @Test
+ public void testMatmul2d() {
+ // Convention: a is the 'outermost' dimension, etc.
+ Tensor.Builder ab = Tensor.Builder.of(TensorType.fromSpec("tensor(a[2],b[3])"));
+ ab.cell( 1,0, 0);
+ ab.cell( 2,0, 1);
+ ab.cell( 3,0, 2);
+ ab.cell( 4,1, 0);
+ ab.cell( 5,1, 1);
+ ab.cell( 6,1, 2);
+ Tensor a = ab.build();
+
+ Tensor.Builder bb = Tensor.Builder.of(TensorType.fromSpec("tensor(a[3],b[2])"));
+ bb.cell( 7,0, 0);
+ bb.cell( 8,0, 1);
+ bb.cell( 9,1, 0);
+ bb.cell(10,1, 1);
+ bb.cell(11,2, 0);
+ bb.cell(12,2, 1);
+ Tensor b = bb.build();
+
+ Tensor.Builder rb = Tensor.Builder.of(TensorType.fromSpec("tensor(a[2],c[2])"));
+ rb.cell( 58,0, 0);
+ rb.cell( 64,0, 1);
+ rb.cell(139,1, 0);
+ rb.cell(154,1, 1);
+ Tensor r = rb.build();
+
+ Tensor result = a.matmul(b.rename(ImmutableList.of("a","b"),ImmutableList.of("b","c")), "b");
+ assertEquals(r, result);
+ }
+
+ @Test
+ public void testMatmul3d() {
+ // Convention: a is the 'outermost' dimension, etc.
+ Tensor.Builder ab = Tensor.Builder.of(TensorType.fromSpec("tensor(a[2],b[2],c[3])"));
+ ab.cell( 1,0, 0, 0);
+ ab.cell( 2,0, 0, 1);
+ ab.cell( 3,0, 0, 2);
+ ab.cell( 4,0, 1, 0);
+ ab.cell( 5,0, 1, 1);
+ ab.cell( 6,0, 1, 2);
+ ab.cell( 7,1, 0, 0);
+ ab.cell( 8,1, 0, 1);
+ ab.cell( 9,1, 0, 2);
+ ab.cell(10,1, 1, 0);
+ ab.cell(11,1, 1, 1);
+ ab.cell(12,1, 1, 2);
+ Tensor a = ab.build();
+
+ Tensor.Builder bb = Tensor.Builder.of(TensorType.fromSpec("tensor(a[2],b[3],c[2])"));
+ bb.cell(13,0, 0, 0);
+ bb.cell(14,0, 0, 1);
+ bb.cell(15,0, 1, 0);
+ bb.cell(16,0, 1, 1);
+ bb.cell(17,0, 2, 0);
+ bb.cell(18,0, 2, 1);
+ bb.cell(19,1, 0, 0);
+ bb.cell(20,1, 0, 1);
+ bb.cell(21,1, 1, 0);
+ bb.cell(22,1, 1, 1);
+ bb.cell(23,1, 2, 0);
+ bb.cell(24,1, 2, 1);
+ Tensor b = bb.build();
+
+ Tensor.Builder rb = Tensor.Builder.of(TensorType.fromSpec("tensor(a[2],b[2],d[2])"));
+ rb.cell( 94,0, 0, 0);
+ rb.cell(100,0, 0, 1);
+ rb.cell(229,0, 1, 0);
+ rb.cell(244,0, 1, 1);
+ rb.cell(508,1, 0, 0);
+ rb.cell(532,1, 0, 1);
+ rb.cell(697,1, 1, 0);
+ rb.cell(730,1, 1, 1);
+ Tensor r = rb.build();
+
+ Tensor result = a.matmul(b.rename(ImmutableList.of("b","c"),ImmutableList.of("c","d")), "c");
+ System.out.println(result);
+ }
+
+}