From 3df3c57607c73bda31a60af6695aeafd8a57fabb Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Sat, 2 Dec 2017 16:52:34 -0800 Subject: Import and return constant tensors --- .../integration/tensorflow/NamedTensor.java | 23 +++++++++++ .../integration/tensorflow/OperationMapper.java | 27 ++++++------ .../integration/tensorflow/TensorFlowImporter.java | 48 ++++++++++++++-------- .../rankingexpression/rule/TensorFunctionNode.java | 26 ++++++------ .../tensorflow/TensorFlowImporterTestCase.java | 23 +++++++++-- 5 files changed, 100 insertions(+), 47 deletions(-) create mode 100644 searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/NamedTensor.java (limited to 'searchlib') diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/NamedTensor.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/NamedTensor.java new file mode 100644 index 00000000000..235771bfa9c --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/NamedTensor.java @@ -0,0 +1,23 @@ +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; + +import com.yahoo.tensor.Tensor; + +/** + * A tensor with a name + * + * @author bratseth + */ +public class NamedTensor { + + private final String name; + private final Tensor tensor; + + public NamedTensor(String name, Tensor tensor) { + this.name = name; + this.tensor = tensor; + } + + public String name() { return name; } + public Tensor tensor() { return tensor; } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java index 0717d3e1b2b..183cfabbd87 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java @@ -1,10 +1,11 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow; import com.google.common.collect.ImmutableList; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.VariableTensor; -import com.yahoo.tensor.functions.ConstantTensor; import com.yahoo.tensor.functions.Join; import com.yahoo.tensor.functions.Matmul; import com.yahoo.tensor.functions.Rename; @@ -49,7 +50,7 @@ class OperationMapper { // (and if not, this should do the right thing anyway) ensureArguments(2, arguments, "join"); TypedTensorFunction a = arguments.get(0); - TypedTensorFunction b = arguments.get(0); + TypedTensorFunction b = arguments.get(1); TensorType resultType = Join.outputType(a.type(), b.type()); Join function = new Join(a.function(), b.function(), doubleFunction); @@ -65,9 +66,10 @@ class OperationMapper { return new TypedTensorFunction(resultType, function); } - TypedTensorFunction identity(NodeDef tfNode, Map inputs, SavedModelBundle model) { + TypedTensorFunction identity(NodeDef tfNode, Map inputs, SavedModelBundle model, + List constants) { String name; - TensorType inputType; + TensorType type; if (tfNode.getName().endsWith("/read")) { // A node reading a variable supplied with this model if (tfNode.getInputList().size() != 1) throw new IllegalArgumentException("A Variable/read node must have one input but has " + @@ -75,29 +77,30 @@ class OperationMapper { name = tfNode.getInput(0); AttrValue shapes = tfNode.getAttrMap().get("_output_shapes"); if (shapes == null) - throw new IllegalArgumentException("Referenced variable '" + name + " is missing a tensor output shape"); - inputType = TensorFlowImporter.importTensorType(shapes.getList().getShape(0)); - Session.Runner fetched = model.session().runner().fetch("Variable"); + throw new IllegalArgumentException("Referenced variable '" + name + "' is missing a tensor output shape"); + Session.Runner fetched = model.session().runner().fetch(name); List> result = fetched.run(); if ( result.size() != 1) throw new IllegalStateException("Expected 1 tensor from reading Variable " + name + ", but got " + result.size()); Tensor constant = tensorConverter.toVespaTensor(result.get(0)); - return new TypedTensorFunction(inputType, new ConstantTensor(constant)); + constants.add(new NamedTensor(name, constant)); + return new TypedTensorFunction(constant.type(), + new TensorFunctionNode.TensorFunctionExpressionNode(new ReferenceNode("constant(" + name + ")"))); } else { // a referenced input (query or document tensor) TODO: How to map to attribute/query name name = tfNode.getName(); - inputType = inputs.get(name); - if (inputType == null) + type = inputs.get(name); + if (type == null) throw new IllegalArgumentException("An identity operation node is referencing input '" + name + "', but there is no such input"); - return new TypedTensorFunction(inputType, new VariableTensor(name)); + return new TypedTensorFunction(type, new VariableTensor(name)); } } TypedTensorFunction matmul(List arguments) { ensureArguments(2, arguments, "matmul"); TypedTensorFunction a = arguments.get(0); - TypedTensorFunction b = arguments.get(0); + TypedTensorFunction b = arguments.get(1); if (a.type().rank() < 2 || b.type().rank() < 2) throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2"); if (a.type().rank() != b.type().rank()) diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index c14f8c71a3e..51f1e444e70 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -5,6 +5,7 @@ import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.ScalarFunctions; +import com.yahoo.tensor.functions.TensorFunction; import com.yahoo.yolean.Exceptions; import org.tensorflow.SavedModelBundle; import org.tensorflow.framework.GraphDef; @@ -35,12 +36,16 @@ public class TensorFlowImporter { * Imports a saved TensorFlow model from a directory. * The model should be saved as a pbtxt file. * The name of the model is taken at the pbtxt file name (not including the .pbtxt ending). + * + * @param modelDir the directory containing the TensorFlow model files to import + * @param constants any constant tensors imported from the TensorFlow model and referenced in the returned expressions + * @param logger a receiver of any messages generated by the import process + * @return the ranking expressions resulting from importing this TenorFlow model */ - public List importModel(String modelDir, MessageLogger logger) { + public List importModel(String modelDir, List constants, MessageLogger logger) { try { SavedModelBundle model = SavedModelBundle.load(modelDir, "serve"); - return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), model, logger); - + return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), model, constants, logger); } catch (IOException e) { throw new IllegalArgumentException("Could not open TensorFlow model directory '" + modelDir + "'", e); @@ -48,7 +53,8 @@ public class TensorFlowImporter { } - private List importGraph(MetaGraphDef graph, SavedModelBundle model, MessageLogger logger) { + private List importGraph(MetaGraphDef graph, SavedModelBundle model, + List constants, MessageLogger logger) { List expressions = new ArrayList<>(); for (Map.Entry signatureEntry : graph.getSignatureDefMap().entrySet()) { Map inputs = importInputs(signatureEntry.getValue().getInputsMap()); @@ -57,7 +63,8 @@ public class TensorFlowImporter { ExpressionNode result = importOutput(output.getValue(), inputs, graph.getGraphDef(), - model); + model, + constants); expressions.add(new RankingExpression(output.getKey(), result)); } catch (IllegalArgumentException e) { @@ -77,7 +84,7 @@ public class TensorFlowImporter { return inputs; } - static TensorType importTensorType(TensorShapeProto tensorShape) { + private TensorType importTensorType(TensorShapeProto tensorShape) { TensorType.Builder b = new TensorType.Builder(); for (int i = 0; i < tensorShape.getDimCount(); i++) { int dimensionSize = (int) tensorShape.getDim(i).getSize(); @@ -89,28 +96,32 @@ public class TensorFlowImporter { return b.build(); } - private ExpressionNode importOutput(TensorInfo output, Map inputs, GraphDef graph, SavedModelBundle model) { + private ExpressionNode importOutput(TensorInfo output, Map inputs, GraphDef graph, + SavedModelBundle model, List constants) { NodeDef node = getNode(nameOf(output.getName()), graph); - return new TensorFunctionNode(importNode(node, inputs, graph, model).function()); + TensorFunction function = importNode(node, inputs, graph, model, constants).function(); + return new TensorFunctionNode(function); // wrap top level (only) as an expression } /** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */ - private TypedTensorFunction importNode(NodeDef tfNode, Map inputs, GraphDef graph, SavedModelBundle model) { - return tensorFunctionOf(tfNode, inputs, graph, model); + private TypedTensorFunction importNode(NodeDef tfNode, Map inputs, GraphDef graph, + SavedModelBundle model, List constants) { + return tensorFunctionOf(tfNode, inputs, graph, model, constants); } private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, Map inputs, GraphDef graph, - SavedModelBundle model) { + SavedModelBundle model, + List constants) { // Import arguments lazily below, as some nodes have arguments unused arguments leading to unsupported ops // TODO: Implement mapping of more functions from https://www.tensorflow.org/api_docs/python/ switch (tfNode.getOp().toLowerCase()) { - case "add" : case "add_n" : return operationMapper.join(importArguments(tfNode, inputs, graph, model), ScalarFunctions.add()); - case "acos" : return operationMapper.map(importArguments(tfNode, inputs, graph, model), ScalarFunctions.acos()); - case "identity" : return operationMapper.identity(tfNode, inputs, model); - case "matmul" : return operationMapper.matmul(importArguments(tfNode, inputs, graph, model)); - case "softmax" : return operationMapper.softmax(importArguments(tfNode, inputs, graph, model)); + case "add" : case "add_n" : return operationMapper.join(importArguments(tfNode, inputs, graph, model, constants), ScalarFunctions.add()); + case "acos" : return operationMapper.map(importArguments(tfNode, inputs, graph, model, constants), ScalarFunctions.acos()); + case "identity" : return operationMapper.identity(tfNode, inputs, model, constants); + case "matmul" : return operationMapper.matmul(importArguments(tfNode, inputs, graph, model, constants)); + case "softmax" : return operationMapper.softmax(importArguments(tfNode, inputs, graph, model, constants)); default : throw new IllegalArgumentException("Conversion of TensorFlow operation '" + tfNode.getOp() + "' is not supported"); } } @@ -118,9 +129,10 @@ public class TensorFlowImporter { private List importArguments(NodeDef tfNode, Map inputs, GraphDef graph, - SavedModelBundle model) { + SavedModelBundle model, + List constants) { return tfNode.getInputList().stream() - .map(argNode -> importNode(getNode(nameOf(argNode), graph), inputs, graph, model)) + .map(argNode -> importNode(getNode(nameOf(argNode), graph), inputs, graph, model, constants)) .collect(Collectors.toList()); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index ab5f1e7191d..d1f4cbddf6e 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -21,18 +21,18 @@ import java.util.stream.Collectors; * * @author bratseth */ - @Beta +@Beta public class TensorFunctionNode extends CompositeNode { private final TensorFunction function; - + public TensorFunctionNode(TensorFunction function) { this.function = function; } /** Returns the tensor function wrapped by this */ public TensorFunction function() { return function; } - + @Override public List children() { return function.functionArguments().stream() @@ -53,7 +53,7 @@ public class TensorFunctionNode extends CompositeNode { // Serialize as primitive return function.toPrimitive().toString(new ExpressionNodeToStringContext(context, path, this)); } - + @Override public Value evaluate(Context context) { return new TensorValue(function.evaluate(context)); @@ -62,8 +62,8 @@ public class TensorFunctionNode extends CompositeNode { public static TensorFunctionExpressionNode wrapArgument(ExpressionNode node) { return new TensorFunctionExpressionNode(node); } - - /** + + /** * A tensor function implemented by an expression. * This allows us to pass expressions as tensor function arguments. */ @@ -71,13 +71,13 @@ public class TensorFunctionNode extends CompositeNode { /** An expression which produces a tensor */ private final ExpressionNode expression; - + public TensorFunctionExpressionNode(ExpressionNode expression) { this.expression = expression; } - + @Override - public List functionArguments() { + public List functionArguments() { if (expression instanceof CompositeNode) return ((CompositeNode)expression).children().stream() .map(TensorFunctionExpressionNode::new) @@ -111,7 +111,7 @@ public class TensorFunctionNode extends CompositeNode { public String toString() { return toString(ExpressionNodeToStringContext.empty); } - + @Override public String toString(ToStringContext c) { if (c instanceof ExpressionNodeToStringContext) { @@ -124,14 +124,14 @@ public class TensorFunctionNode extends CompositeNode { } } - + /** Allows passing serialization context arguments through TensorFunctions */ private static class ExpressionNodeToStringContext implements ToStringContext { - + final SerializationContext context; final Deque path; final CompositeNode parent; - + public static final ExpressionNodeToStringContext empty = new ExpressionNodeToStringContext(null, null, null); public ExpressionNodeToStringContext(SerializationContext context, Deque path, CompositeNode parent) { diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java index 936936dc3eb..aaf198a9e8f 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java @@ -2,6 +2,7 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.TensorType; import org.junit.Test; import java.util.ArrayList; @@ -18,9 +19,23 @@ public class TensorFlowImporterTestCase { @Test public void testModel1() { + List constants = new ArrayList<>(); TestLogger logger = new TestLogger(); List expressions = - new TensorFlowImporter().importModel("src/test/files/integration/tensorflow/model1/", logger); + new TensorFlowImporter().importModel("src/test/files/integration/tensorflow/model1/", constants, logger); + + // Check constants + assertEquals(2, constants.size()); + + assertEquals("Variable", constants.get(0).name()); + assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(), + constants.get(0).tensor().type()); + assertEquals(7840, constants.get(0).tensor().size()); + + assertEquals("Variable_1", constants.get(1).name()); + assertEquals(new TensorType.Builder().indexed("d0", 10).build(), + constants.get(1).tensor().type()); + assertEquals(10, constants.get(1).tensor().size()); // Check logged messages assertEquals(2, logger.messages().size()); @@ -33,10 +48,10 @@ public class TensorFlowImporterTestCase { assertEquals(1, expressions.size()); assertEquals("scores", expressions.get(0).getName()); assertEquals("" + - "softmax(join(rename(matmul(x, rename(x, (d1, d2), (d2, d3)), d2), d3, d2), " + - "rename(matmul(x, rename(x, (d1, d2), (d2, d3)), d2), d3, d2), " + + "softmax(join(rename(matmul(x, rename(constant(Variable), (d1, d2), (d2, d3)), d2), d3, d2), " + + "constant(Variable_1), " + "f(a,b)(a + b)), " + - "d1)", + "d0)", toNonPrimitiveString(expressions.get(0))); } -- cgit v1.2.3