diff options
author | Lester Solbakken <lesters@users.noreply.github.com> | 2017-12-21 08:54:11 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-12-21 08:54:11 +0100 |
commit | a23acc0d0f9f0f6c651bcf352bd2cfed6d33debc (patch) | |
tree | 3faba807627773ac594d7b90c45cebe64b44af3f /searchlib | |
parent | 431f9496782fa7d3d9c35abfcde2a2c1b3c55621 (diff) | |
parent | d97e6b0a72a7c7d38365475637bff3897bc7597a (diff) |
Merge pull request #4509 from vespa-engine/bratseth/tensorflow-models-2
Bratseth/tensorflow models 2
Diffstat (limited to 'searchlib')
4 files changed, 128 insertions, 77 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java index b4a9b363ade..947e6d7a5e1 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java @@ -12,40 +12,91 @@ import java.util.Map; import java.util.stream.Collectors; /** - * The result of importing a TensorFlow model into Vespa: - * - A list of ranking expressions reproducing the computations of the outputs in the TensorFlow model - * - A list of named constant tensors - * - A list of expected input tensors, with their tensor type - * - A list of warning messages + * The result of importing a TensorFlow model into Vespa. + * - A set of signatures which are named collections of inputs and outputs. + * - A set of named constant tensors represented by Variable nodes in TensorFlow. + * - A list of warning messages. * * @author bratseth */ // This object can be built incrementally within this package, but is immutable when observed from outside the package -// TODO: Retain signature structure in ImportResult (input + output-expression bundles) public class ImportResult { - private final List<RankingExpression> expressions = new ArrayList<>(); - private final Map<String, Tensor> constants = new HashMap<>(); + private final Map<String, Signature> signatures = new HashMap<>(); private final Map<String, TensorType> arguments = new HashMap<>(); + private final Map<String, Tensor> constants = new HashMap<>(); + private final Map<String, RankingExpression> expressions = new HashMap<>(); private final List<String> warnings = new ArrayList<>(); - void add(RankingExpression expression) { expressions.add(expression); } - void set(String name, Tensor constant) { constants.put(name, constant); } - void set(String name, TensorType argument) { arguments.put(name, argument); } + void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); } + void constant(String name, Tensor constant) { constants.put(name, constant); } + void expression(String name, RankingExpression expression) { expressions.put(name, expression); } void warn(String warning) { warnings.add(warning); } - /** Returns an immutable list of the expressions of this */ - public List<RankingExpression> expressions() { return Collections.unmodifiableList(expressions); } + /** Returns the given signature. If it does not already exist it is added to this. */ + Signature signature(String name) { + return signatures.computeIfAbsent(name, n -> new Signature(n)); + } + + /** Returns an immutable map of the arguments ("Placeholders") of this */ + public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); } /** Returns an immutable map of the constants of this */ public Map<String, Tensor> constants() { return Collections.unmodifiableMap(constants); } - /** Returns an immutable map of the arguments of this */ - public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); } + /** + * Returns an immutable map of the expressions of this - corresponding to TensorFlow nodes + * which are not Placeholders or Variables (which instead become respectively arguments and constants). + * Note that only nodes recursively referenced by a placeholder are added. + */ + public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); } /** Returns an immutable list, in natural sort order of the warnings generated while importing this */ public List<String> warnings() { return warnings.stream().sorted().collect(Collectors.toList()); } + /** Returns an immutable map of the signatures of this */ + public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); } + + /** + * A signature is a set of named inputs and outputs, where the inputs maps to argument ("placeholder") names+types, + * and outputs maps to expressions nodes. + */ + public class Signature { + + private final String name; + private final Map<String, String> inputs = new HashMap<>(); + private final Map<String, String> outputs = new HashMap<>(); + + Signature(String name) { + this.name = name; + } + + void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); } + void output(String name, String expressionName) { outputs.put(name, expressionName); } + + /** Returns the result this is part of */ + ImportResult owner() { return ImportResult.this; } + + /** + * Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name + * to argument (Placeholder) name in the owner of this + */ + public Map<String, String> inputs() { return Collections.unmodifiableMap(inputs); } + + /** Returns owner().arguments().get(inputs.get(name)), e.g the type of the argument this input references */ + public TensorType inputArgument(String inputName) { return owner().arguments().get(inputs.get(inputName)); } + + /** Returns an immutable list of the expression names of this */ + public Map<String, String> outputs() { return Collections.unmodifiableMap(outputs); } + + /** Returns owner().expressions().get(outputs.get(outputName)), e.g the expression this output references */ + public RankingExpression outputExpression(String outputName) { return owner().expressions().get(outputs.get(outputName)); } + + @Override + public String toString() { return "signature '" + name + "'"; } + + } + } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java index e7f7b5ef2f4..bac141644c6 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java @@ -90,8 +90,8 @@ class OperationMapper { String name = tfNode.getName(); TensorType type = result.arguments().get(name); if (type == null) - throw new IllegalArgumentException("An placeholder operation node is referencing input '" + name + - "', but there is no such input"); + throw new IllegalArgumentException("A 'placeholder' node is referencing placeholder '" + name + + "', but there is no such placeholder"); // Included literally in the expression and so must be produced by a separate macro in the rank profile return new TypedTensorFunction(type, new VariableTensor(name)); } @@ -114,7 +114,7 @@ class OperationMapper { throw new IllegalStateException("Expected 1 tensor from reading Variable " + name + ", but got " + importedTensors.size()); Tensor constant = tensorConverter.toVespaTensor(importedTensors.get(0)); - result.set(name, constant); + result.constant(name, constant); return new TypedTensorFunction(constant.type(), new TensorFunctionNode.TensorFunctionExpressionNode(new ReferenceNode("constant(" + name + ")"))); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index 33523244129..4a6551adca7 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -1,11 +1,9 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow; import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.ScalarFunctions; -import com.yahoo.tensor.functions.TensorFunction; import com.yahoo.yolean.Exceptions; import org.tensorflow.SavedModelBundle; import org.tensorflow.framework.GraphDef; @@ -38,49 +36,53 @@ public class TensorFlowImporter { */ public ImportResult importModel(String modelDir) { try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) { - return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), model); + return importModel(model); } - catch (IOException e) { - throw new IllegalArgumentException("Could not read TensorFlow model from directory '" + modelDir + "'", e); + catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e); } } - public ImportResult importNode(String modelDir, String inputSignatureName, String nodeName) { - try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) { - MetaGraphDef graph = MetaGraphDef.parseFrom(model.metaGraphDef()); - SignatureDef signature = graph.getSignatureDefMap().get(inputSignatureName); - ImportResult result = new ImportResult(); - importInputs(signature.getInputsMap(), result); - result.add(new RankingExpression(nodeName, importNode(nodeName, graph.getGraphDef(), model, result))); - return result; + /** Imports a TensorFlow model */ + public ImportResult importModel(SavedModelBundle model) { + try { + return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), model); } catch (IOException e) { - throw new IllegalArgumentException("Could not read TensorFlow model from directory '" + modelDir + "'", e); + throw new IllegalArgumentException("Could not import TensorFlow model '" + model + "'", e); } } private ImportResult importGraph(MetaGraphDef graph, SavedModelBundle model) { ImportResult result = new ImportResult(); for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) { - importInputs(signatureEntry.getValue().getInputsMap(), result); + ImportResult.Signature signature = result.signature(signatureEntry.getKey()); // Prefer key over "methodName" + + importInputs(signatureEntry.getValue().getInputsMap(), signature); for (Map.Entry<String, TensorInfo> output : signatureEntry.getValue().getOutputsMap().entrySet()) { + String outputName = output.getKey(); try { - ExpressionNode node = importOutput(output.getValue(), graph.getGraphDef(), model, result); - result.add(new RankingExpression(output.getKey(), node)); + NodeDef node = getNode(nameOf(output.getValue().getName()), graph.getGraphDef()); + importNode(node, graph.getGraphDef(), model, result); + signature.output(outputName, nameOf(output.getValue().getName())); } catch (IllegalArgumentException e) { - result.warn("Skipping output '" + output.getValue().getName() + "' of signature '" + - signatureEntry.getValue().getMethodName() + - "': " + Exceptions.toMessageString(e)); + result.warn("Skipping output '" + outputName + "' of " + signature + + ": " + Exceptions.toMessageString(e)); } } } return result; } - private void importInputs(Map<String, TensorInfo> inputInfoMap, ImportResult result) { - inputInfoMap.forEach((key, value) -> result.set(nameOf(value.getName()), - importTensorType(value.getTensorShape()))); + private void importInputs(Map<String, TensorInfo> inputInfoMap, ImportResult.Signature signature) { + inputInfoMap.forEach((key, value) -> { + String argumentName = nameOf(value.getName()); + TensorType argumentType = importTensorType(value.getTensorShape()); + // Arguments are (Placeholder) nodes, so not local to the signature: + signature.owner().argument(argumentName, argumentType); + signature.input(key, argumentName); + }); } private TensorType importTensorType(TensorShapeProto tensorShape) { @@ -95,18 +97,13 @@ public class TensorFlowImporter { return b.build(); } - private ExpressionNode importOutput(TensorInfo output, GraphDef graph, SavedModelBundle model, ImportResult result) { - return importNode(nameOf(output.getName()), graph, model, result); - } - - private ExpressionNode importNode(String nodeName, GraphDef graph, SavedModelBundle model, ImportResult result) { - TensorFunction function = importNode(getNode(nodeName, graph), graph, model, result).function(); - return new TensorFunctionNode(function); // wrap top level (only) as an expression - } - /** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */ private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) { - return tensorFunctionOf(tfNode, graph, model, result); + TypedTensorFunction function = tensorFunctionOf(tfNode, graph, model, result); + // We add all intermediate nodes imported as separate expressions. Only those referenced in a signature output + // will be used + result.expression(tfNode.getName(), new RankingExpression(tfNode.getName(), new TensorFunctionNode(function.function()))); + return function; } private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) { @@ -123,7 +120,8 @@ public class TensorFlowImporter { } } - private List<TypedTensorFunction> importArguments(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) { + private List<TypedTensorFunction> importArguments(NodeDef tfNode, GraphDef graph, SavedModelBundle model, + ImportResult result) { return tfNode.getInputList().stream() .map(argNode -> importNode(getNode(nameOf(argNode), graph), graph, model, result)) .collect(Collectors.toList()); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java index d50a97cc8e0..0370fc7fc94 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java @@ -27,18 +27,13 @@ public class Mnist_SoftmaxTestCase { @Test public void testImporting() { String modelDir = "src/test/files/integration/tensorflow/mnist_softmax/saved"; - ImportResult result = new TensorFlowImporter().importModel(modelDir); + SavedModelBundle model = SavedModelBundle.load(modelDir, "serve"); + ImportResult result = new TensorFlowImporter().importModel(model); // Check logged messages result.warnings().forEach(System.err::println); assertEquals(0, result.warnings().size()); - // Check arguments - assertEquals(1, result.arguments().size()); - TensorType argument0 = result.arguments().get("Placeholder"); - assertNotNull(argument0); - assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), argument0); - // Check constants assertEquals(2, result.constants().size()); @@ -54,38 +49,45 @@ public class Mnist_SoftmaxTestCase { constant1.type()); assertEquals(10, constant1.size()); - // Check resulting Vespa expression - assertEquals(1, result.expressions().size()); - assertEquals("y", result.expressions().get(0).getName()); + // Check signatures + assertEquals(1, result.signatures().size()); + ImportResult.Signature signature = result.signatures().get("serving_default"); + assertNotNull(signature); + + // ... signature inputs + assertEquals(1, signature.inputs().size()); + TensorType argument0 = signature.inputArgument("x"); + assertNotNull(argument0); + assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), argument0); + + // ... signature outputs + assertEquals(1, signature.outputs().size()); + RankingExpression output = signature.outputExpression("y"); + assertNotNull(output); + assertEquals("add", output.getName()); assertEquals("" + "join(rename(matmul(Placeholder, rename(constant(Variable), (d0, d1), (d1, d3)), d1), d3, d1), " + "rename(constant(Variable_1), d0, d1), " + "f(a,b)(a + b))", - toNonPrimitiveString(result.expressions().get(0))); + toNonPrimitiveString(output)); // Test execution - String signatureName = "serving_default"; - - assertEqualResult(modelDir, signatureName, "Variable/read"); - assertEqualResult(modelDir, signatureName, "Variable_1/read"); - // TODO: Assert that argument fed is as expected assertEqualResult(modelDir, signatureName, "Placeholder"); - assertEqualResult(modelDir, signatureName, "MatMul"); - assertEqualResult(modelDir, signatureName, "add"); + assertEqualResult(model, result, "Variable/read"); + assertEqualResult(model, result, "Variable_1/read"); + assertEqualResult(model, result, "MatMul"); + assertEqualResult(model, result, "add"); } - private void assertEqualResult(String modelDir, String signatureName, String operationName) { - ImportResult result = new TensorFlowImporter().importNode(modelDir, signatureName, operationName); - - Tensor tfResult = tensorFlowExecute(modelDir, operationName); + private void assertEqualResult(SavedModelBundle model, ImportResult result, String operationName) { + Tensor tfResult = tensorFlowExecute(model, operationName); Context context = contextFrom(result); Tensor placeholder = placeholderArgument(); context.put("Placeholder", new TensorValue(placeholder)); - Tensor vespaResult = result.expressions().get(0).evaluate(context).asTensor(); + Tensor vespaResult = result.expressions().get(operationName).evaluate(context).asTensor(); assertEquals("Operation '" + operationName + "' produces equal results", vespaResult, tfResult); } - private Tensor tensorFlowExecute(String modelDir, String operationName) { - SavedModelBundle model = SavedModelBundle.load(modelDir, "serve"); + private Tensor tensorFlowExecute(SavedModelBundle model, String operationName) { Session.Runner runner = model.session().runner(); org.tensorflow.Tensor<?> placeholder = org.tensorflow.Tensor.create(new long[]{ 1, 784 }, FloatBuffer.allocate(784)); runner.feed("Placeholder", placeholder); |