From 5dce0c978c36936c7372e32d1a05f05c0b61386e Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Wed, 20 Dec 2017 15:13:57 +0100 Subject: Model signatures in import results --- .../integration/tensorflow/ImportResult.java | 70 ++++++++++++++++----- .../integration/tensorflow/OperationMapper.java | 12 ++-- .../integration/tensorflow/TensorFlowImporter.java | 71 +++++++++++++--------- .../tensorflow/Mnist_SoftmaxTestCase.java | 31 ++++++---- 4 files changed, 123 insertions(+), 61 deletions(-) (limited to 'searchlib') diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java index b4a9b363ade..b3c1708a0f4 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java @@ -12,40 +12,80 @@ import java.util.Map; import java.util.stream.Collectors; /** - * The result of importing a TensorFlow model into Vespa: - * - A list of ranking expressions reproducing the computations of the outputs in the TensorFlow model - * - A list of named constant tensors - * - A list of expected input tensors, with their tensor type - * - A list of warning messages + * The result of importing a TensorFlow model into Vespa. + * - A set of signatures which are named collections of inputs and outputs. + * - A set of named constant tensors represented by Variable nodes in TensorFlow. + * - A list of warning messages. * * @author bratseth */ // This object can be built incrementally within this package, but is immutable when observed from outside the package -// TODO: Retain signature structure in ImportResult (input + output-expression bundles) public class ImportResult { - private final List expressions = new ArrayList<>(); - private final Map constants = new HashMap<>(); + private final Map signatures = new HashMap<>(); private final Map arguments = new HashMap<>(); + private final Map constants = new HashMap<>(); private final List warnings = new ArrayList<>(); - void add(RankingExpression expression) { expressions.add(expression); } - void set(String name, Tensor constant) { constants.put(name, constant); } - void set(String name, TensorType argument) { arguments.put(name, argument); } + void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); } + void constant(String name, Tensor constant) { constants.put(name, constant); } void warn(String warning) { warnings.add(warning); } - /** Returns an immutable list of the expressions of this */ - public List expressions() { return Collections.unmodifiableList(expressions); } + /** Returns the given signature. If it does not already exist it is added to this. */ + Signature signature(String name) { + return signatures.computeIfAbsent(name, n -> new Signature(n)); + } + + /** Returns an immutable map of the arguments ("Placeholders") of this */ + public Map arguments() { return Collections.unmodifiableMap(arguments); } /** Returns an immutable map of the constants of this */ public Map constants() { return Collections.unmodifiableMap(constants); } - /** Returns an immutable map of the arguments of this */ - public Map arguments() { return Collections.unmodifiableMap(arguments); } + /** Returns an immutable map of the signatures of this */ + public Map signatures() { return Collections.unmodifiableMap(signatures); } /** Returns an immutable list, in natural sort order of the warnings generated while importing this */ public List warnings() { return warnings.stream().sorted().collect(Collectors.toList()); } + /** + * A signature is a set of named inputs and outputs, where the inputs maps to argument ("placeholder") names+types, + * and outputs maps to ranking expressions stemming from conversion of TensorFlow nodes and the inputs make up the + * context which is needed to evaluate the expression. + */ + public class Signature { + + private final String name; + private final Map inputs = new HashMap<>(); + private final Map outputs = new HashMap<>(); + + Signature(String name) { + this.name = name; + } + + void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); } + void output(String name, RankingExpression expression) { outputs.put(name, expression); } + + /** Returns the result this is part of */ + ImportResult owner() { return ImportResult.this; } + + /** + * Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name + * to argument (Placeholder) name in the owner of this + */ + public Map inputs() { return Collections.unmodifiableMap(inputs); } + + /** Returns owner().arguments().get(inputs.get(name)), e.g the type of the argument this input references */ + public TensorType inputType(String inputName) { return owner().arguments().get(inputs.get(inputName)); } + + /** Returns an immutable list of the expressions of this */ + public Map outputs() { return Collections.unmodifiableMap(outputs); } + + @Override + public String toString() { return "signature '" + name + "'"; } + + } + } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java index e7f7b5ef2f4..5e2d7530200 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java @@ -86,17 +86,17 @@ class OperationMapper { return new TypedTensorFunction(resultType, function); } - TypedTensorFunction placeholder(NodeDef tfNode, ImportResult result) { + TypedTensorFunction placeholder(NodeDef tfNode, ImportResult.Signature signature) { String name = tfNode.getName(); - TensorType type = result.arguments().get(name); + TensorType type = signature.owner().arguments().get(name); if (type == null) - throw new IllegalArgumentException("An placeholder operation node is referencing input '" + name + - "', but there is no such input"); + throw new IllegalArgumentException("A 'placeholder' node is referencing placeholder '" + name + + "', but there is no such placeholder"); // Included literally in the expression and so must be produced by a separate macro in the rank profile return new TypedTensorFunction(type, new VariableTensor(name)); } - TypedTensorFunction identity(NodeDef tfNode, SavedModelBundle model, ImportResult result) { + TypedTensorFunction identity(NodeDef tfNode, SavedModelBundle model, ImportResult.Signature signature) { if ( ! tfNode.getName().endsWith("/read")) throw new IllegalArgumentException("Encountered identity node " + tfNode.getName() + ", but identify " + "nodes are only supported when reading variables"); @@ -114,7 +114,7 @@ class OperationMapper { throw new IllegalStateException("Expected 1 tensor from reading Variable " + name + ", but got " + importedTensors.size()); Tensor constant = tensorConverter.toVespaTensor(importedTensors.get(0)); - result.set(name, constant); + signature.owner().constant(name, constant); return new TypedTensorFunction(constant.type(), new TensorFunctionNode.TensorFunctionExpressionNode(new ReferenceNode("constant(" + name + ")"))); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index 33523244129..66bfe9fcfb9 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -45,13 +45,16 @@ public class TensorFlowImporter { } } - public ImportResult importNode(String modelDir, String inputSignatureName, String nodeName) { + /** Imports a specific node as an putput given the name of that node. Useful for testing */ + public ImportResult importNode(String modelDir, String signatureName, String nodeName) { try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) { MetaGraphDef graph = MetaGraphDef.parseFrom(model.metaGraphDef()); - SignatureDef signature = graph.getSignatureDefMap().get(inputSignatureName); + SignatureDef signatureDef = graph.getSignatureDefMap().get(signatureName); ImportResult result = new ImportResult(); - importInputs(signature.getInputsMap(), result); - result.add(new RankingExpression(nodeName, importNode(nodeName, graph.getGraphDef(), model, result))); + ImportResult.Signature signature = result.signature(signatureName); + importInputs(signatureDef.getInputsMap(), signature); + signature.output(nodeName, + new RankingExpression(nodeName, importNode(nodeName, graph.getGraphDef(), model, signature))); return result; } catch (IOException e) { @@ -62,25 +65,32 @@ public class TensorFlowImporter { private ImportResult importGraph(MetaGraphDef graph, SavedModelBundle model) { ImportResult result = new ImportResult(); for (Map.Entry signatureEntry : graph.getSignatureDefMap().entrySet()) { - importInputs(signatureEntry.getValue().getInputsMap(), result); + ImportResult.Signature signature = result.signature(signatureEntry.getKey()); // Prefer key over "methodName" + + importInputs(signatureEntry.getValue().getInputsMap(), signature); for (Map.Entry output : signatureEntry.getValue().getOutputsMap().entrySet()) { + String outputName = output.getKey(); try { - ExpressionNode node = importOutput(output.getValue(), graph.getGraphDef(), model, result); - result.add(new RankingExpression(output.getKey(), node)); + ExpressionNode node = importOutput(output.getValue(), graph.getGraphDef(), model, signature); + signature.output(outputName, new RankingExpression(outputName, node)); } catch (IllegalArgumentException e) { - result.warn("Skipping output '" + output.getValue().getName() + "' of signature '" + - signatureEntry.getValue().getMethodName() + - "': " + Exceptions.toMessageString(e)); + result.warn("Skipping output '" + outputName + "' of " + signature + + ": " + Exceptions.toMessageString(e)); } } } return result; } - private void importInputs(Map inputInfoMap, ImportResult result) { - inputInfoMap.forEach((key, value) -> result.set(nameOf(value.getName()), - importTensorType(value.getTensorShape()))); + private void importInputs(Map inputInfoMap, ImportResult.Signature signature) { + inputInfoMap.forEach((key, value) -> { + String argumentName = nameOf(value.getName()); + TensorType argumentType = importTensorType(value.getTensorShape()); + // Arguments are (Placeholder) nodes, so not local to the signature: + signature.owner().argument(argumentName, argumentType); + signature.input(key, argumentName); + }); } private TensorType importTensorType(TensorShapeProto tensorShape) { @@ -95,37 +105,42 @@ public class TensorFlowImporter { return b.build(); } - private ExpressionNode importOutput(TensorInfo output, GraphDef graph, SavedModelBundle model, ImportResult result) { - return importNode(nameOf(output.getName()), graph, model, result); + private ExpressionNode importOutput(TensorInfo output, GraphDef graph, SavedModelBundle model, + ImportResult.Signature signature) { + return importNode(nameOf(output.getName()), graph, model, signature); } - private ExpressionNode importNode(String nodeName, GraphDef graph, SavedModelBundle model, ImportResult result) { - TensorFunction function = importNode(getNode(nodeName, graph), graph, model, result).function(); + private ExpressionNode importNode(String nodeName, GraphDef graph, SavedModelBundle model, + ImportResult.Signature signature) { + TensorFunction function = importNode(getNode(nodeName, graph), graph, model, signature).function(); return new TensorFunctionNode(function); // wrap top level (only) as an expression } /** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */ - private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) { - return tensorFunctionOf(tfNode, graph, model, result); + private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, + ImportResult.Signature signature) { + return tensorFunctionOf(tfNode, graph, model, signature); } - private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) { + private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, + ImportResult.Signature signature) { // Import arguments lazily below, as some nodes have arguments unused arguments leading to unsupported ops // TODO: Implement mapping of more functions from https://www.tensorflow.org/api_docs/python/ switch (tfNode.getOp().toLowerCase()) { - case "add" : case "add_n" : return operationMapper.join(importArguments(tfNode, graph, model, result), ScalarFunctions.add()); - case "acos" : return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.acos()); - case "placeholder" : return operationMapper.placeholder(tfNode, result); - case "identity" : return operationMapper.identity(tfNode, model, result); - case "matmul" : return operationMapper.matmul(importArguments(tfNode, graph, model, result)); - case "softmax" : return operationMapper.softmax(importArguments(tfNode, graph, model, result)); + case "add" : case "add_n" : return operationMapper.join(importArguments(tfNode, graph, model, signature), ScalarFunctions.add()); + case "acos" : return operationMapper.map(importArguments(tfNode, graph, model, signature), ScalarFunctions.acos()); + case "placeholder" : return operationMapper.placeholder(tfNode, signature); + case "identity" : return operationMapper.identity(tfNode, model, signature); + case "matmul" : return operationMapper.matmul(importArguments(tfNode, graph, model, signature)); + case "softmax" : return operationMapper.softmax(importArguments(tfNode, graph, model, signature)); default : throw new IllegalArgumentException("Conversion of TensorFlow operation '" + tfNode.getOp() + "' is not supported"); } } - private List importArguments(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) { + private List importArguments(NodeDef tfNode, GraphDef graph, SavedModelBundle model, + ImportResult.Signature signature) { return tfNode.getInputList().stream() - .map(argNode -> importNode(getNode(nameOf(argNode), graph), graph, model, result)) + .map(argNode -> importNode(getNode(nameOf(argNode), graph), graph, model, signature)) .collect(Collectors.toList()); } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java index d50a97cc8e0..53989af4460 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java @@ -33,12 +33,6 @@ public class Mnist_SoftmaxTestCase { result.warnings().forEach(System.err::println); assertEquals(0, result.warnings().size()); - // Check arguments - assertEquals(1, result.arguments().size()); - TensorType argument0 = result.arguments().get("Placeholder"); - assertNotNull(argument0); - assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), argument0); - // Check constants assertEquals(2, result.constants().size()); @@ -54,18 +48,31 @@ public class Mnist_SoftmaxTestCase { constant1.type()); assertEquals(10, constant1.size()); - // Check resulting Vespa expression - assertEquals(1, result.expressions().size()); - assertEquals("y", result.expressions().get(0).getName()); + // Check signatures + assertEquals(1, result.signatures().size()); + ImportResult.Signature signature = result.signatures().get("serving_default"); + assertNotNull(signature); + + // ... signature inputs + assertEquals(1, signature.inputs().size()); + TensorType argument0 = signature.inputType("x"); + assertNotNull(argument0); + assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), argument0); + + // ... signature outputs + assertEquals(1, signature.outputs().size()); + RankingExpression output = signature.outputs().get("y"); + assertNotNull(output); + assertEquals("y", output.getName()); assertEquals("" + "join(rename(matmul(Placeholder, rename(constant(Variable), (d0, d1), (d1, d3)), d1), d3, d1), " + "rename(constant(Variable_1), d0, d1), " + "f(a,b)(a + b))", - toNonPrimitiveString(result.expressions().get(0))); + toNonPrimitiveString(output)); // Test execution + // TODO: Pass imported result instead of re-importing String signatureName = "serving_default"; - assertEqualResult(modelDir, signatureName, "Variable/read"); assertEqualResult(modelDir, signatureName, "Variable_1/read"); // TODO: Assert that argument fed is as expected assertEqualResult(modelDir, signatureName, "Placeholder"); @@ -80,7 +87,7 @@ public class Mnist_SoftmaxTestCase { Context context = contextFrom(result); Tensor placeholder = placeholderArgument(); context.put("Placeholder", new TensorValue(placeholder)); - Tensor vespaResult = result.expressions().get(0).evaluate(context).asTensor(); + Tensor vespaResult = result.signatures().get(signatureName).outputs().get(operationName).evaluate(context).asTensor(); assertEquals("Operation '" + operationName + "' produces equal results", vespaResult, tfResult); } -- cgit v1.2.3