summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@users.noreply.github.com>2017-12-21 08:54:11 +0100
committerGitHub <noreply@github.com>2017-12-21 08:54:11 +0100
commita23acc0d0f9f0f6c651bcf352bd2cfed6d33debc (patch)
tree3faba807627773ac594d7b90c45cebe64b44af3f /searchlib
parent431f9496782fa7d3d9c35abfcde2a2c1b3c55621 (diff)
parentd97e6b0a72a7c7d38365475637bff3897bc7597a (diff)
Merge pull request #4509 from vespa-engine/bratseth/tensorflow-models-2
Bratseth/tensorflow models 2
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java81
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java66
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java52
4 files changed, 128 insertions, 77 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java
index b4a9b363ade..947e6d7a5e1 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java
@@ -12,40 +12,91 @@ import java.util.Map;
import java.util.stream.Collectors;
/**
- * The result of importing a TensorFlow model into Vespa:
- * - A list of ranking expressions reproducing the computations of the outputs in the TensorFlow model
- * - A list of named constant tensors
- * - A list of expected input tensors, with their tensor type
- * - A list of warning messages
+ * The result of importing a TensorFlow model into Vespa.
+ * - A set of signatures which are named collections of inputs and outputs.
+ * - A set of named constant tensors represented by Variable nodes in TensorFlow.
+ * - A list of warning messages.
*
* @author bratseth
*/
// This object can be built incrementally within this package, but is immutable when observed from outside the package
-// TODO: Retain signature structure in ImportResult (input + output-expression bundles)
public class ImportResult {
- private final List<RankingExpression> expressions = new ArrayList<>();
- private final Map<String, Tensor> constants = new HashMap<>();
+ private final Map<String, Signature> signatures = new HashMap<>();
private final Map<String, TensorType> arguments = new HashMap<>();
+ private final Map<String, Tensor> constants = new HashMap<>();
+ private final Map<String, RankingExpression> expressions = new HashMap<>();
private final List<String> warnings = new ArrayList<>();
- void add(RankingExpression expression) { expressions.add(expression); }
- void set(String name, Tensor constant) { constants.put(name, constant); }
- void set(String name, TensorType argument) { arguments.put(name, argument); }
+ void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); }
+ void constant(String name, Tensor constant) { constants.put(name, constant); }
+ void expression(String name, RankingExpression expression) { expressions.put(name, expression); }
void warn(String warning) { warnings.add(warning); }
- /** Returns an immutable list of the expressions of this */
- public List<RankingExpression> expressions() { return Collections.unmodifiableList(expressions); }
+ /** Returns the given signature. If it does not already exist it is added to this. */
+ Signature signature(String name) {
+ return signatures.computeIfAbsent(name, n -> new Signature(n));
+ }
+
+ /** Returns an immutable map of the arguments ("Placeholders") of this */
+ public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); }
/** Returns an immutable map of the constants of this */
public Map<String, Tensor> constants() { return Collections.unmodifiableMap(constants); }
- /** Returns an immutable map of the arguments of this */
- public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); }
+ /**
+ * Returns an immutable map of the expressions of this - corresponding to TensorFlow nodes
+ * which are not Placeholders or Variables (which instead become respectively arguments and constants).
+ * Note that only nodes recursively referenced by a placeholder are added.
+ */
+ public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); }
/** Returns an immutable list, in natural sort order of the warnings generated while importing this */
public List<String> warnings() {
return warnings.stream().sorted().collect(Collectors.toList());
}
+ /** Returns an immutable map of the signatures of this */
+ public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); }
+
+ /**
+ * A signature is a set of named inputs and outputs, where the inputs maps to argument ("placeholder") names+types,
+ * and outputs maps to expressions nodes.
+ */
+ public class Signature {
+
+ private final String name;
+ private final Map<String, String> inputs = new HashMap<>();
+ private final Map<String, String> outputs = new HashMap<>();
+
+ Signature(String name) {
+ this.name = name;
+ }
+
+ void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); }
+ void output(String name, String expressionName) { outputs.put(name, expressionName); }
+
+ /** Returns the result this is part of */
+ ImportResult owner() { return ImportResult.this; }
+
+ /**
+ * Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name
+ * to argument (Placeholder) name in the owner of this
+ */
+ public Map<String, String> inputs() { return Collections.unmodifiableMap(inputs); }
+
+ /** Returns owner().arguments().get(inputs.get(name)), e.g the type of the argument this input references */
+ public TensorType inputArgument(String inputName) { return owner().arguments().get(inputs.get(inputName)); }
+
+ /** Returns an immutable list of the expression names of this */
+ public Map<String, String> outputs() { return Collections.unmodifiableMap(outputs); }
+
+ /** Returns owner().expressions().get(outputs.get(outputName)), e.g the expression this output references */
+ public RankingExpression outputExpression(String outputName) { return owner().expressions().get(outputs.get(outputName)); }
+
+ @Override
+ public String toString() { return "signature '" + name + "'"; }
+
+ }
+
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
index e7f7b5ef2f4..bac141644c6 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
@@ -90,8 +90,8 @@ class OperationMapper {
String name = tfNode.getName();
TensorType type = result.arguments().get(name);
if (type == null)
- throw new IllegalArgumentException("An placeholder operation node is referencing input '" + name +
- "', but there is no such input");
+ throw new IllegalArgumentException("A 'placeholder' node is referencing placeholder '" + name +
+ "', but there is no such placeholder");
// Included literally in the expression and so must be produced by a separate macro in the rank profile
return new TypedTensorFunction(type, new VariableTensor(name));
}
@@ -114,7 +114,7 @@ class OperationMapper {
throw new IllegalStateException("Expected 1 tensor from reading Variable " + name + ", but got " +
importedTensors.size());
Tensor constant = tensorConverter.toVespaTensor(importedTensors.get(0));
- result.set(name, constant);
+ result.constant(name, constant);
return new TypedTensorFunction(constant.type(),
new TensorFunctionNode.TensorFunctionExpressionNode(new ReferenceNode("constant(" + name + ")")));
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
index 33523244129..4a6551adca7 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -1,11 +1,9 @@
package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.ScalarFunctions;
-import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.yolean.Exceptions;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.framework.GraphDef;
@@ -38,49 +36,53 @@ public class TensorFlowImporter {
*/
public ImportResult importModel(String modelDir) {
try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) {
- return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), model);
+ return importModel(model);
}
- catch (IOException e) {
- throw new IllegalArgumentException("Could not read TensorFlow model from directory '" + modelDir + "'", e);
+ catch (IllegalArgumentException e) {
+ throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e);
}
}
- public ImportResult importNode(String modelDir, String inputSignatureName, String nodeName) {
- try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) {
- MetaGraphDef graph = MetaGraphDef.parseFrom(model.metaGraphDef());
- SignatureDef signature = graph.getSignatureDefMap().get(inputSignatureName);
- ImportResult result = new ImportResult();
- importInputs(signature.getInputsMap(), result);
- result.add(new RankingExpression(nodeName, importNode(nodeName, graph.getGraphDef(), model, result)));
- return result;
+ /** Imports a TensorFlow model */
+ public ImportResult importModel(SavedModelBundle model) {
+ try {
+ return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), model);
}
catch (IOException e) {
- throw new IllegalArgumentException("Could not read TensorFlow model from directory '" + modelDir + "'", e);
+ throw new IllegalArgumentException("Could not import TensorFlow model '" + model + "'", e);
}
}
private ImportResult importGraph(MetaGraphDef graph, SavedModelBundle model) {
ImportResult result = new ImportResult();
for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) {
- importInputs(signatureEntry.getValue().getInputsMap(), result);
+ ImportResult.Signature signature = result.signature(signatureEntry.getKey()); // Prefer key over "methodName"
+
+ importInputs(signatureEntry.getValue().getInputsMap(), signature);
for (Map.Entry<String, TensorInfo> output : signatureEntry.getValue().getOutputsMap().entrySet()) {
+ String outputName = output.getKey();
try {
- ExpressionNode node = importOutput(output.getValue(), graph.getGraphDef(), model, result);
- result.add(new RankingExpression(output.getKey(), node));
+ NodeDef node = getNode(nameOf(output.getValue().getName()), graph.getGraphDef());
+ importNode(node, graph.getGraphDef(), model, result);
+ signature.output(outputName, nameOf(output.getValue().getName()));
}
catch (IllegalArgumentException e) {
- result.warn("Skipping output '" + output.getValue().getName() + "' of signature '" +
- signatureEntry.getValue().getMethodName() +
- "': " + Exceptions.toMessageString(e));
+ result.warn("Skipping output '" + outputName + "' of " + signature +
+ ": " + Exceptions.toMessageString(e));
}
}
}
return result;
}
- private void importInputs(Map<String, TensorInfo> inputInfoMap, ImportResult result) {
- inputInfoMap.forEach((key, value) -> result.set(nameOf(value.getName()),
- importTensorType(value.getTensorShape())));
+ private void importInputs(Map<String, TensorInfo> inputInfoMap, ImportResult.Signature signature) {
+ inputInfoMap.forEach((key, value) -> {
+ String argumentName = nameOf(value.getName());
+ TensorType argumentType = importTensorType(value.getTensorShape());
+ // Arguments are (Placeholder) nodes, so not local to the signature:
+ signature.owner().argument(argumentName, argumentType);
+ signature.input(key, argumentName);
+ });
}
private TensorType importTensorType(TensorShapeProto tensorShape) {
@@ -95,18 +97,13 @@ public class TensorFlowImporter {
return b.build();
}
- private ExpressionNode importOutput(TensorInfo output, GraphDef graph, SavedModelBundle model, ImportResult result) {
- return importNode(nameOf(output.getName()), graph, model, result);
- }
-
- private ExpressionNode importNode(String nodeName, GraphDef graph, SavedModelBundle model, ImportResult result) {
- TensorFunction function = importNode(getNode(nodeName, graph), graph, model, result).function();
- return new TensorFunctionNode(function); // wrap top level (only) as an expression
- }
-
/** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */
private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {
- return tensorFunctionOf(tfNode, graph, model, result);
+ TypedTensorFunction function = tensorFunctionOf(tfNode, graph, model, result);
+ // We add all intermediate nodes imported as separate expressions. Only those referenced in a signature output
+ // will be used
+ result.expression(tfNode.getName(), new RankingExpression(tfNode.getName(), new TensorFunctionNode(function.function())));
+ return function;
}
private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {
@@ -123,7 +120,8 @@ public class TensorFlowImporter {
}
}
- private List<TypedTensorFunction> importArguments(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {
+ private List<TypedTensorFunction> importArguments(NodeDef tfNode, GraphDef graph, SavedModelBundle model,
+ ImportResult result) {
return tfNode.getInputList().stream()
.map(argNode -> importNode(getNode(nameOf(argNode), graph), graph, model, result))
.collect(Collectors.toList());
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java
index d50a97cc8e0..0370fc7fc94 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java
@@ -27,18 +27,13 @@ public class Mnist_SoftmaxTestCase {
@Test
public void testImporting() {
String modelDir = "src/test/files/integration/tensorflow/mnist_softmax/saved";
- ImportResult result = new TensorFlowImporter().importModel(modelDir);
+ SavedModelBundle model = SavedModelBundle.load(modelDir, "serve");
+ ImportResult result = new TensorFlowImporter().importModel(model);
// Check logged messages
result.warnings().forEach(System.err::println);
assertEquals(0, result.warnings().size());
- // Check arguments
- assertEquals(1, result.arguments().size());
- TensorType argument0 = result.arguments().get("Placeholder");
- assertNotNull(argument0);
- assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), argument0);
-
// Check constants
assertEquals(2, result.constants().size());
@@ -54,38 +49,45 @@ public class Mnist_SoftmaxTestCase {
constant1.type());
assertEquals(10, constant1.size());
- // Check resulting Vespa expression
- assertEquals(1, result.expressions().size());
- assertEquals("y", result.expressions().get(0).getName());
+ // Check signatures
+ assertEquals(1, result.signatures().size());
+ ImportResult.Signature signature = result.signatures().get("serving_default");
+ assertNotNull(signature);
+
+ // ... signature inputs
+ assertEquals(1, signature.inputs().size());
+ TensorType argument0 = signature.inputArgument("x");
+ assertNotNull(argument0);
+ assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), argument0);
+
+ // ... signature outputs
+ assertEquals(1, signature.outputs().size());
+ RankingExpression output = signature.outputExpression("y");
+ assertNotNull(output);
+ assertEquals("add", output.getName());
assertEquals("" +
"join(rename(matmul(Placeholder, rename(constant(Variable), (d0, d1), (d1, d3)), d1), d3, d1), " +
"rename(constant(Variable_1), d0, d1), " +
"f(a,b)(a + b))",
- toNonPrimitiveString(result.expressions().get(0)));
+ toNonPrimitiveString(output));
// Test execution
- String signatureName = "serving_default";
-
- assertEqualResult(modelDir, signatureName, "Variable/read");
- assertEqualResult(modelDir, signatureName, "Variable_1/read");
- // TODO: Assert that argument fed is as expected assertEqualResult(modelDir, signatureName, "Placeholder");
- assertEqualResult(modelDir, signatureName, "MatMul");
- assertEqualResult(modelDir, signatureName, "add");
+ assertEqualResult(model, result, "Variable/read");
+ assertEqualResult(model, result, "Variable_1/read");
+ assertEqualResult(model, result, "MatMul");
+ assertEqualResult(model, result, "add");
}
- private void assertEqualResult(String modelDir, String signatureName, String operationName) {
- ImportResult result = new TensorFlowImporter().importNode(modelDir, signatureName, operationName);
-
- Tensor tfResult = tensorFlowExecute(modelDir, operationName);
+ private void assertEqualResult(SavedModelBundle model, ImportResult result, String operationName) {
+ Tensor tfResult = tensorFlowExecute(model, operationName);
Context context = contextFrom(result);
Tensor placeholder = placeholderArgument();
context.put("Placeholder", new TensorValue(placeholder));
- Tensor vespaResult = result.expressions().get(0).evaluate(context).asTensor();
+ Tensor vespaResult = result.expressions().get(operationName).evaluate(context).asTensor();
assertEquals("Operation '" + operationName + "' produces equal results", vespaResult, tfResult);
}
- private Tensor tensorFlowExecute(String modelDir, String operationName) {
- SavedModelBundle model = SavedModelBundle.load(modelDir, "serve");
+ private Tensor tensorFlowExecute(SavedModelBundle model, String operationName) {
Session.Runner runner = model.session().runner();
org.tensorflow.Tensor<?> placeholder = org.tensorflow.Tensor.create(new long[]{ 1, 784 }, FloatBuffer.allocate(784));
runner.feed("Placeholder", placeholder);