summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2017-12-20 15:13:57 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2017-12-20 15:13:57 +0100
commit5dce0c978c36936c7372e32d1a05f05c0b61386e (patch)
treeaf890c5e82aab612dd7f4943ddf18685ab2ca99d /searchlib
parent2be1c34825fda8bdf9711c1e7522989fe3a8a45e (diff)
Model signatures in import results
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java70
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java12
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java71
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java31
4 files changed, 123 insertions, 61 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java
index b4a9b363ade..b3c1708a0f4 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java
@@ -12,40 +12,80 @@ import java.util.Map;
import java.util.stream.Collectors;
/**
- * The result of importing a TensorFlow model into Vespa:
- * - A list of ranking expressions reproducing the computations of the outputs in the TensorFlow model
- * - A list of named constant tensors
- * - A list of expected input tensors, with their tensor type
- * - A list of warning messages
+ * The result of importing a TensorFlow model into Vespa.
+ * - A set of signatures which are named collections of inputs and outputs.
+ * - A set of named constant tensors represented by Variable nodes in TensorFlow.
+ * - A list of warning messages.
*
* @author bratseth
*/
// This object can be built incrementally within this package, but is immutable when observed from outside the package
-// TODO: Retain signature structure in ImportResult (input + output-expression bundles)
public class ImportResult {
- private final List<RankingExpression> expressions = new ArrayList<>();
- private final Map<String, Tensor> constants = new HashMap<>();
+ private final Map<String, Signature> signatures = new HashMap<>();
private final Map<String, TensorType> arguments = new HashMap<>();
+ private final Map<String, Tensor> constants = new HashMap<>();
private final List<String> warnings = new ArrayList<>();
- void add(RankingExpression expression) { expressions.add(expression); }
- void set(String name, Tensor constant) { constants.put(name, constant); }
- void set(String name, TensorType argument) { arguments.put(name, argument); }
+ void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); }
+ void constant(String name, Tensor constant) { constants.put(name, constant); }
void warn(String warning) { warnings.add(warning); }
- /** Returns an immutable list of the expressions of this */
- public List<RankingExpression> expressions() { return Collections.unmodifiableList(expressions); }
+ /** Returns the given signature. If it does not already exist it is added to this. */
+ Signature signature(String name) {
+ return signatures.computeIfAbsent(name, n -> new Signature(n));
+ }
+
+ /** Returns an immutable map of the arguments ("Placeholders") of this */
+ public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); }
/** Returns an immutable map of the constants of this */
public Map<String, Tensor> constants() { return Collections.unmodifiableMap(constants); }
- /** Returns an immutable map of the arguments of this */
- public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); }
+ /** Returns an immutable map of the signatures of this */
+ public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); }
/** Returns an immutable list, in natural sort order of the warnings generated while importing this */
public List<String> warnings() {
return warnings.stream().sorted().collect(Collectors.toList());
}
+ /**
+ * A signature is a set of named inputs and outputs, where the inputs maps to argument ("placeholder") names+types,
+ * and outputs maps to ranking expressions stemming from conversion of TensorFlow nodes and the inputs make up the
+ * context which is needed to evaluate the expression.
+ */
+ public class Signature {
+
+ private final String name;
+ private final Map<String, String> inputs = new HashMap<>();
+ private final Map<String, RankingExpression> outputs = new HashMap<>();
+
+ Signature(String name) {
+ this.name = name;
+ }
+
+ void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); }
+ void output(String name, RankingExpression expression) { outputs.put(name, expression); }
+
+ /** Returns the result this is part of */
+ ImportResult owner() { return ImportResult.this; }
+
+ /**
+ * Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name
+ * to argument (Placeholder) name in the owner of this
+ */
+ public Map<String, String> inputs() { return Collections.unmodifiableMap(inputs); }
+
+ /** Returns owner().arguments().get(inputs.get(name)), e.g the type of the argument this input references */
+ public TensorType inputType(String inputName) { return owner().arguments().get(inputs.get(inputName)); }
+
+ /** Returns an immutable list of the expressions of this */
+ public Map<String, RankingExpression> outputs() { return Collections.unmodifiableMap(outputs); }
+
+ @Override
+ public String toString() { return "signature '" + name + "'"; }
+
+ }
+
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
index e7f7b5ef2f4..5e2d7530200 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
@@ -86,17 +86,17 @@ class OperationMapper {
return new TypedTensorFunction(resultType, function);
}
- TypedTensorFunction placeholder(NodeDef tfNode, ImportResult result) {
+ TypedTensorFunction placeholder(NodeDef tfNode, ImportResult.Signature signature) {
String name = tfNode.getName();
- TensorType type = result.arguments().get(name);
+ TensorType type = signature.owner().arguments().get(name);
if (type == null)
- throw new IllegalArgumentException("An placeholder operation node is referencing input '" + name +
- "', but there is no such input");
+ throw new IllegalArgumentException("A 'placeholder' node is referencing placeholder '" + name +
+ "', but there is no such placeholder");
// Included literally in the expression and so must be produced by a separate macro in the rank profile
return new TypedTensorFunction(type, new VariableTensor(name));
}
- TypedTensorFunction identity(NodeDef tfNode, SavedModelBundle model, ImportResult result) {
+ TypedTensorFunction identity(NodeDef tfNode, SavedModelBundle model, ImportResult.Signature signature) {
if ( ! tfNode.getName().endsWith("/read"))
throw new IllegalArgumentException("Encountered identity node " + tfNode.getName() + ", but identify " +
"nodes are only supported when reading variables");
@@ -114,7 +114,7 @@ class OperationMapper {
throw new IllegalStateException("Expected 1 tensor from reading Variable " + name + ", but got " +
importedTensors.size());
Tensor constant = tensorConverter.toVespaTensor(importedTensors.get(0));
- result.set(name, constant);
+ signature.owner().constant(name, constant);
return new TypedTensorFunction(constant.type(),
new TensorFunctionNode.TensorFunctionExpressionNode(new ReferenceNode("constant(" + name + ")")));
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
index 33523244129..66bfe9fcfb9 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -45,13 +45,16 @@ public class TensorFlowImporter {
}
}
- public ImportResult importNode(String modelDir, String inputSignatureName, String nodeName) {
+ /** Imports a specific node as an putput given the name of that node. Useful for testing */
+ public ImportResult importNode(String modelDir, String signatureName, String nodeName) {
try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) {
MetaGraphDef graph = MetaGraphDef.parseFrom(model.metaGraphDef());
- SignatureDef signature = graph.getSignatureDefMap().get(inputSignatureName);
+ SignatureDef signatureDef = graph.getSignatureDefMap().get(signatureName);
ImportResult result = new ImportResult();
- importInputs(signature.getInputsMap(), result);
- result.add(new RankingExpression(nodeName, importNode(nodeName, graph.getGraphDef(), model, result)));
+ ImportResult.Signature signature = result.signature(signatureName);
+ importInputs(signatureDef.getInputsMap(), signature);
+ signature.output(nodeName,
+ new RankingExpression(nodeName, importNode(nodeName, graph.getGraphDef(), model, signature)));
return result;
}
catch (IOException e) {
@@ -62,25 +65,32 @@ public class TensorFlowImporter {
private ImportResult importGraph(MetaGraphDef graph, SavedModelBundle model) {
ImportResult result = new ImportResult();
for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) {
- importInputs(signatureEntry.getValue().getInputsMap(), result);
+ ImportResult.Signature signature = result.signature(signatureEntry.getKey()); // Prefer key over "methodName"
+
+ importInputs(signatureEntry.getValue().getInputsMap(), signature);
for (Map.Entry<String, TensorInfo> output : signatureEntry.getValue().getOutputsMap().entrySet()) {
+ String outputName = output.getKey();
try {
- ExpressionNode node = importOutput(output.getValue(), graph.getGraphDef(), model, result);
- result.add(new RankingExpression(output.getKey(), node));
+ ExpressionNode node = importOutput(output.getValue(), graph.getGraphDef(), model, signature);
+ signature.output(outputName, new RankingExpression(outputName, node));
}
catch (IllegalArgumentException e) {
- result.warn("Skipping output '" + output.getValue().getName() + "' of signature '" +
- signatureEntry.getValue().getMethodName() +
- "': " + Exceptions.toMessageString(e));
+ result.warn("Skipping output '" + outputName + "' of " + signature +
+ ": " + Exceptions.toMessageString(e));
}
}
}
return result;
}
- private void importInputs(Map<String, TensorInfo> inputInfoMap, ImportResult result) {
- inputInfoMap.forEach((key, value) -> result.set(nameOf(value.getName()),
- importTensorType(value.getTensorShape())));
+ private void importInputs(Map<String, TensorInfo> inputInfoMap, ImportResult.Signature signature) {
+ inputInfoMap.forEach((key, value) -> {
+ String argumentName = nameOf(value.getName());
+ TensorType argumentType = importTensorType(value.getTensorShape());
+ // Arguments are (Placeholder) nodes, so not local to the signature:
+ signature.owner().argument(argumentName, argumentType);
+ signature.input(key, argumentName);
+ });
}
private TensorType importTensorType(TensorShapeProto tensorShape) {
@@ -95,37 +105,42 @@ public class TensorFlowImporter {
return b.build();
}
- private ExpressionNode importOutput(TensorInfo output, GraphDef graph, SavedModelBundle model, ImportResult result) {
- return importNode(nameOf(output.getName()), graph, model, result);
+ private ExpressionNode importOutput(TensorInfo output, GraphDef graph, SavedModelBundle model,
+ ImportResult.Signature signature) {
+ return importNode(nameOf(output.getName()), graph, model, signature);
}
- private ExpressionNode importNode(String nodeName, GraphDef graph, SavedModelBundle model, ImportResult result) {
- TensorFunction function = importNode(getNode(nodeName, graph), graph, model, result).function();
+ private ExpressionNode importNode(String nodeName, GraphDef graph, SavedModelBundle model,
+ ImportResult.Signature signature) {
+ TensorFunction function = importNode(getNode(nodeName, graph), graph, model, signature).function();
return new TensorFunctionNode(function); // wrap top level (only) as an expression
}
/** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */
- private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {
- return tensorFunctionOf(tfNode, graph, model, result);
+ private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model,
+ ImportResult.Signature signature) {
+ return tensorFunctionOf(tfNode, graph, model, signature);
}
- private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {
+ private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model,
+ ImportResult.Signature signature) {
// Import arguments lazily below, as some nodes have arguments unused arguments leading to unsupported ops
// TODO: Implement mapping of more functions from https://www.tensorflow.org/api_docs/python/
switch (tfNode.getOp().toLowerCase()) {
- case "add" : case "add_n" : return operationMapper.join(importArguments(tfNode, graph, model, result), ScalarFunctions.add());
- case "acos" : return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.acos());
- case "placeholder" : return operationMapper.placeholder(tfNode, result);
- case "identity" : return operationMapper.identity(tfNode, model, result);
- case "matmul" : return operationMapper.matmul(importArguments(tfNode, graph, model, result));
- case "softmax" : return operationMapper.softmax(importArguments(tfNode, graph, model, result));
+ case "add" : case "add_n" : return operationMapper.join(importArguments(tfNode, graph, model, signature), ScalarFunctions.add());
+ case "acos" : return operationMapper.map(importArguments(tfNode, graph, model, signature), ScalarFunctions.acos());
+ case "placeholder" : return operationMapper.placeholder(tfNode, signature);
+ case "identity" : return operationMapper.identity(tfNode, model, signature);
+ case "matmul" : return operationMapper.matmul(importArguments(tfNode, graph, model, signature));
+ case "softmax" : return operationMapper.softmax(importArguments(tfNode, graph, model, signature));
default : throw new IllegalArgumentException("Conversion of TensorFlow operation '" + tfNode.getOp() + "' is not supported");
}
}
- private List<TypedTensorFunction> importArguments(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {
+ private List<TypedTensorFunction> importArguments(NodeDef tfNode, GraphDef graph, SavedModelBundle model,
+ ImportResult.Signature signature) {
return tfNode.getInputList().stream()
- .map(argNode -> importNode(getNode(nameOf(argNode), graph), graph, model, result))
+ .map(argNode -> importNode(getNode(nameOf(argNode), graph), graph, model, signature))
.collect(Collectors.toList());
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java
index d50a97cc8e0..53989af4460 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java
@@ -33,12 +33,6 @@ public class Mnist_SoftmaxTestCase {
result.warnings().forEach(System.err::println);
assertEquals(0, result.warnings().size());
- // Check arguments
- assertEquals(1, result.arguments().size());
- TensorType argument0 = result.arguments().get("Placeholder");
- assertNotNull(argument0);
- assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), argument0);
-
// Check constants
assertEquals(2, result.constants().size());
@@ -54,18 +48,31 @@ public class Mnist_SoftmaxTestCase {
constant1.type());
assertEquals(10, constant1.size());
- // Check resulting Vespa expression
- assertEquals(1, result.expressions().size());
- assertEquals("y", result.expressions().get(0).getName());
+ // Check signatures
+ assertEquals(1, result.signatures().size());
+ ImportResult.Signature signature = result.signatures().get("serving_default");
+ assertNotNull(signature);
+
+ // ... signature inputs
+ assertEquals(1, signature.inputs().size());
+ TensorType argument0 = signature.inputType("x");
+ assertNotNull(argument0);
+ assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), argument0);
+
+ // ... signature outputs
+ assertEquals(1, signature.outputs().size());
+ RankingExpression output = signature.outputs().get("y");
+ assertNotNull(output);
+ assertEquals("y", output.getName());
assertEquals("" +
"join(rename(matmul(Placeholder, rename(constant(Variable), (d0, d1), (d1, d3)), d1), d3, d1), " +
"rename(constant(Variable_1), d0, d1), " +
"f(a,b)(a + b))",
- toNonPrimitiveString(result.expressions().get(0)));
+ toNonPrimitiveString(output));
// Test execution
+ // TODO: Pass imported result instead of re-importing
String signatureName = "serving_default";
-
assertEqualResult(modelDir, signatureName, "Variable/read");
assertEqualResult(modelDir, signatureName, "Variable_1/read");
// TODO: Assert that argument fed is as expected assertEqualResult(modelDir, signatureName, "Placeholder");
@@ -80,7 +87,7 @@ public class Mnist_SoftmaxTestCase {
Context context = contextFrom(result);
Tensor placeholder = placeholderArgument();
context.put("Placeholder", new TensorValue(placeholder));
- Tensor vespaResult = result.expressions().get(0).evaluate(context).asTensor();
+ Tensor vespaResult = result.signatures().get(signatureName).outputs().get(operationName).evaluate(context).asTensor();
assertEquals("Operation '" + operationName + "' produces equal results", vespaResult, tfResult);
}