From b5ffe229474223844c150e99d24ca618e5e9f8dd Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Tue, 28 Nov 2017 21:35:59 +0100 Subject: Complete prototype TensorFlow mapping --- .../integration/tensorflow/TensorFlowImporter.java | 224 +++++++++++++++++---- .../rule/SerializationContext.java | 2 +- .../rankingexpression/rule/TensorFunctionNode.java | 3 + .../tensorflow/TensorFlowImporterTestCase.java | 22 +- 4 files changed, 212 insertions(+), 39 deletions(-) (limited to 'searchlib') diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index 160af794faf..8dcd31b270e 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -1,15 +1,24 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow; +import com.google.common.collect.ImmutableList; import com.google.protobuf.ProtocolStringList; import com.google.protobuf.TextFormat; import com.yahoo.io.IOUtils; +import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.VariableTensor; +import com.yahoo.tensor.functions.ConstantTensor; import com.yahoo.tensor.functions.Join; import com.yahoo.tensor.functions.Matmul; +import com.yahoo.tensor.functions.Rename; import com.yahoo.tensor.functions.ScalarFunctions; +import com.yahoo.tensor.functions.Softmax; import com.yahoo.tensor.functions.TensorFunction; import com.yahoo.yolean.Exceptions; +import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.MetaGraphDef; import org.tensorflow.framework.NodeDef; @@ -17,11 +26,14 @@ import org.tensorflow.framework.OpDef; import org.tensorflow.framework.SavedModel; import org.tensorflow.framework.SignatureDef; import org.tensorflow.framework.TensorInfo; +import org.tensorflow.framework.TensorShapeProto; import java.io.IOException; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.DoubleBinaryOperator; import java.util.stream.Collectors; /** @@ -31,17 +43,31 @@ import java.util.stream.Collectors; */ public class TensorFlowImporter { + /* + A note on conversion from implicitly numbered to explicitly named dimensions: + Vespa tensor dimensions are explicitly named and thus have an explicit notion of being + 'the same' or not of some dimension in another tensor. Since TF lacks this, each operation + comes with a built-in definition of sameness. We mirror this by wrapping the Vespa tensor operation + around dimension renaming operations which mirrors those built into the TF operation definitions. + + To do this we need a naming convention: We maintain a naming of each tensor where the 'outermost' + dimension is named 'd0', the second outer most 'd1' and so on. Arguments are renamed to match the operation + and the result is then renamed again (if necessary) to recover this convention across a full nested + computation. + + This requires us to track tensor types throughout the conversion. + */ + /** * Imports a saved TensorFlow model from a directory. * The model should be saved as a pbtxt file. * The name of the model is taken at the pbtxt file name (not including the .pbtxt ending). */ - public void importModel(String modelDir) { + public List importModel(String modelDir) { try { SavedModel.Builder builder = SavedModel.newBuilder(); TextFormat.getParser().merge(IOUtils.createReader(modelDir + "/saved_model.pbtxt"), builder); - //System.out.println("Read " + builder); - importModel(builder.build()); + return importModel(builder.build()); // TODO: Support binary reading: //SavedModel.parseFrom(new FileInputStream(modelDir + "/saved_model.pbtxt")); @@ -52,53 +78,161 @@ public class TensorFlowImporter { } - private void importModel(SavedModel model) { - model.getMetaGraphsList().forEach(this::importGraph); + /** Import all declared inputs in all the graphs in the given model */ + private List importModel(SavedModel model) { + // TODO: Handle name conflicts between output keys in different graphs? + return model.getMetaGraphsList().stream() + .flatMap(graph -> importGraph(graph).stream()) + .collect(Collectors.toList()); } - - private void importGraph(MetaGraphDef graph) { + + private List importGraph(MetaGraphDef graph) { System.out.println("Importing graph"); + List expressions = new ArrayList<>(); for (Map.Entry signatureEntry : graph.getSignatureDefMap().entrySet()) { System.out.println(" Importing signature def " + signatureEntry.getKey() + " with method name " + signatureEntry.getValue().getMethodName()); - signatureEntry.getValue().getOutputsMap().values() - .forEach(output -> importOutput(output, signatureEntry.getValue().getMethodName(), graph.getGraphDef())); + Map inputs = importInputs(signatureEntry.getValue().getInputsMap()); + for (Map.Entry output : signatureEntry.getValue().getOutputsMap().entrySet()) { + try { + ExpressionNode result = importOutput(output.getValue(), + inputs, + graph.getGraphDef()); + expressions.add(new RankingExpression(output.getKey(), result)); + } + catch (IllegalArgumentException e) { + System.err.println("Skipping output '" + output.getValue().getName() + "' of signature '" + // TODO: Log, or ... + signatureEntry.getValue().getMethodName() + + "': " + Exceptions.toMessageString(e)); + } + } } + return expressions; + } + + private Map importInputs(Map inputInfoMap) { + Map inputs = new HashMap<>(); + inputInfoMap.forEach((key, value) -> inputs.put(nameOf(value.getName()), + importTensorType(value.getTensorShape()))); + return inputs; } - private void importOutput(TensorInfo output, String signatureName, GraphDef graph) { - try { - System.out.println(" Importing output " + output.getName()); - NodeDef node = getNode(nameOf(output.getName()), graph); - // System.out.println("Ops:-------------"); - // graph.getStrippedOpList().getOpList().stream().forEach(s -> System.out.println(s.getName())); - // System.out.println("-----------------"); - importNode(node, graph, ""); - } - catch (IllegalArgumentException e) { - System.err.println("Skipping output '" + output.getName() + "' of signature '" + signatureName + "': " + Exceptions.toMessageString(e)); + private TensorType importTensorType(TensorShapeProto tensorShape) { + TensorType.Builder b = new TensorType.Builder(); + for (int i = 0; i < tensorShape.getDimCount(); i++) { + int dimensionSize = (int) tensorShape.getDim(i).getSize(); + if (dimensionSize >= 0) + b.indexed("d" + i, dimensionSize); + else + b.indexed("d" + i); // unbound size } + return b.build(); } - private ExpressionNode importNode(NodeDef tfNode, GraphDef graph, String indent) { - System.out.println(" " + indent + "Importing node " + tfNode.getName()); - List arguments = new ArrayList<>(); - for (String input : tfNode.getInputList()) - arguments.add(importNode(getNode(nameOf(input), graph), graph, indent + " ")); - ExpressionNode node = expressionNodeOf(tfNode.getName(), arguments); + private ExpressionNode importOutput(TensorInfo output, Map inputs, GraphDef graph) { + System.out.println(" Importing output " + output.getName()); + NodeDef node = getNode(nameOf(output.getName()), graph); + return new TensorFunctionNode(importNode(node, inputs, graph, "").function()); + } + + /** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */ + private TypedTensorFunction importNode(NodeDef tfNode, Map inputs, GraphDef graph, String indent) { + System.out.println(" " + indent + "Importing node " + tfNode.getName() + " with operation " + tfNode.getOp()); + return tensorFunctionOf(tfNode, inputs, graph, indent); + } + + private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, + Map inputs, + GraphDef graph, + String indent) { + // Import arguments lazily below, as some nodes have arguments unused arguments leading to unsupported ops + switch (tfNode.getOp()) { + case "Identity" : return identity(tfNode, inputs); + case "Add" : return join(importArguments(tfNode, inputs, graph, indent), ScalarFunctions.add()); + case "MatMul" : return matmul(importArguments(tfNode, inputs, graph, indent)); + case "Softmax" : return softmax(importArguments(tfNode, inputs, graph, indent)); + default : throw new IllegalArgumentException("Conversion of TensorFlow operation '" + tfNode.getOp() + "' is not supported"); + } } - private ExpressionNode expressionNodeOf(String node, List arguments) { - return new TensorFunctionNode(tensorFunctionOf(node, arguments.stream() - .map(TensorFunctionNode.TensorFunctionExpressionNode::new) - .collect(Collectors.toList()))); + private List importArguments(NodeDef tfNode, Map inputs, GraphDef graph, String indent) { + return tfNode.getInputList().stream() + .map(argNode -> importNode(getNode(nameOf(argNode), graph), inputs, graph, indent + " ")) + .collect(Collectors.toList()); } - private TensorFunction tensorFunctionOf(String node, List arguments) { - switch (node) { - case "add" : return new Join(arguments.get(0), arguments.get(1), ScalarFunctions.add()); - case "MatMul" : return new Matmul(arguments.get(0), arguments.get(1), ScalarFunctions.add()); + private TypedTensorFunction join(List arguments, DoubleBinaryOperator doubleFunction) { + ensureArguments(2, arguments, "join"); + TypedTensorFunction a = arguments.get(0); + TypedTensorFunction b = arguments.get(0); + // TODO: Verify with TF doc + TensorType resultType = Join.resultType(a.type(), b.type()); + Join function = new Join(a.function(), b.function(), doubleFunction); + return new TypedTensorFunction(resultType, function); + } + + private TypedTensorFunction matmul(List arguments) { + ensureArguments(2, arguments, "matmul"); + TypedTensorFunction a = arguments.get(0); + TypedTensorFunction b = arguments.get(0); + if (a.type().rank() < 2 || b.type.rank() < 2) + throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2"); + if (a.type().rank() != b.type.rank()) + throw new IllegalArgumentException("Tensors in matmul must have the same rank"); + + // Let the second-to-last dimension of the second tensor be the same as the last dimension of the first + // and the last dimension of the second argument be not present in the first argument, while leaving the + // rest of the dimensions the same. Such is the way of implicit dimension name tensor multiplication. + + // TODO: Check if transpose_a or transpose_b is set and rename differently accordingly + + String beforeLastDim = "d" + (a.type().rank() - 1); + String lastDim = "d" + a.type().rank(); + String afterLastDim = "d" + (a.type().rank() + 1); + + Rename renamedB = new Rename(b.function(), ImmutableList.of(beforeLastDim, lastDim), + ImmutableList.of(lastDim, afterLastDim)); + Matmul matmul = new Matmul(a.function(), renamedB, lastDim); + return new TypedTensorFunction(Matmul.resultType(a.type(), b.type(), lastDim), + new Rename(matmul, afterLastDim, lastDim)); + } + + private TypedTensorFunction softmax(List arguments) { + ensureArguments(1, arguments, "softmax"); + TypedTensorFunction a = arguments.get(0); + String dimension = "d0"; // TODO: Verify with TF doc + Softmax softmax = new Softmax(a.function(), dimension); + return new TypedTensorFunction(Softmax.resultType(a.type(), dimension), softmax); + } + + private TypedTensorFunction identity(NodeDef tfNode, Map inputs) { + // TODO: Verify with TF documentation + String name; + TensorType inputType; + if (tfNode.getName().endsWith("/read")) { // A node reading a variable supplied with this model TODO: We need to turn those into constants + if (tfNode.getInputList().size() != 1) + throw new IllegalArgumentException("A Variable/read node must have one input but has " + + tfNode.getInputList().size()); + name = tfNode.getInput(0); + AttrValue shapes = tfNode.getAttrMap().get("_output_shapes"); + if (shapes == null) + throw new IllegalArgumentException("Referenced variable '" + name + " is missing a tensor output shape"); + inputType = importTensorType(shapes.getList().getShape(0)); } + else { // a referenced input (query or document tensor) TODO: How to map to attribute/query name + name = tfNode.getName(); + inputType = inputs.get(name); + if (inputType == null) + throw new IllegalArgumentException("An identity operation node is referencing input '" + name + + "', but there is no such input"); + } + return new TypedTensorFunction(inputType, new VariableTensor(name)); + } + + private void ensureArguments(int count, List arguments, String operationName) { + if ( arguments.size() != count) + throw new IllegalArgumentException("Expected " + count + " arguments to " + operationName + + ", but got " + arguments.size()); } private NodeDef getNode(String name, GraphDef graph) { @@ -120,15 +254,31 @@ public class TensorFlowImporter { } /** - * An output has the form name:index. + * A method signature input and output has the form name:index. * This returns the name part without the index. */ - private String nameOf(String outputName) { - return outputName.split(":")[0]; + private String nameOf(String name) { + return name.split(":")[0]; } private boolean contains(String string, ProtocolStringList strings) { return strings.asByteStringList().stream().anyMatch(s -> s.toStringUtf8().equals(string)); } + + /** A tensor function returning a specific tensor type */ + private static final class TypedTensorFunction { + + private final TensorType type; + private final TensorFunction function; + + public TypedTensorFunction(TensorType type, TensorFunction function) { + this.type = type; + this.function = function; + } + + public TensorType type() { return type; } + public TensorFunction function() { return function; } + + } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java index 1f8db6e036c..ba765d07094 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java @@ -17,7 +17,7 @@ import java.util.Map; * @author bratseth */ public class SerializationContext { - + /** Expression functions indexed by name */ private final ImmutableMap functions; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index ce21e132980..ab5f1e7191d 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -30,6 +30,9 @@ public class TensorFunctionNode extends CompositeNode { this.function = function; } + /** Returns the tensor function wrapped by this */ + public TensorFunction function() { return function; } + @Override public List children() { return function.functionArguments().stream() diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java index 4c511047118..30328c3d9fe 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java @@ -1,7 +1,13 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import org.junit.Test; +import java.util.List; + +import static org.junit.Assert.assertEquals; + /** * @author bratseth */ @@ -9,7 +15,21 @@ public class TensorFlowImporterTestCase { @Test public void testModel1() { - new TensorFlowImporter().importModel("src/test/files/integration/tensorflow/model1/"); + List expressions = + new TensorFlowImporter().importModel("src/test/files/integration/tensorflow/model1/"); + assertEquals(1, expressions.size()); + assertEquals("scores", expressions.get(0).getName()); + assertEquals("" + + "softmax(join(rename(matmul(x, rename(x, (d1, d2), (d2, d3)), d2), d3, d2), " + + "rename(matmul(x, rename(x, (d1, d2), (d2, d3)), d2), d3, d2), " + + "f(a,b)(a + b)), " + + "d0)", + toNonPrimitiveString(expressions.get(0))); + } + + private String toNonPrimitiveString(RankingExpression expression) { + // toString on the wrapping expression will map to primitives, which is harder to read + return ((TensorFunctionNode)expression.getRoot()).function().toString(); } } -- cgit v1.2.3