summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2017-12-20 15:56:44 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2017-12-20 15:56:44 +0100
commit82ff058a65bb3ae4b195906a8f927d00846b7fcc (patch)
tree1717189e0a6d51996b13b6e782a6008d73d71c99 /searchlib
parentf4e87644d6ee9801786aa3cc9101bce162cfe55a (diff)
Only import once
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java29
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java32
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java31
3 files changed, 47 insertions, 45 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java
index b3c1708a0f4..947e6d7a5e1 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java
@@ -25,10 +25,12 @@ public class ImportResult {
private final Map<String, Signature> signatures = new HashMap<>();
private final Map<String, TensorType> arguments = new HashMap<>();
private final Map<String, Tensor> constants = new HashMap<>();
+ private final Map<String, RankingExpression> expressions = new HashMap<>();
private final List<String> warnings = new ArrayList<>();
void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); }
void constant(String name, Tensor constant) { constants.put(name, constant); }
+ void expression(String name, RankingExpression expression) { expressions.put(name, expression); }
void warn(String warning) { warnings.add(warning); }
/** Returns the given signature. If it does not already exist it is added to this. */
@@ -42,31 +44,37 @@ public class ImportResult {
/** Returns an immutable map of the constants of this */
public Map<String, Tensor> constants() { return Collections.unmodifiableMap(constants); }
- /** Returns an immutable map of the signatures of this */
- public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); }
+ /**
+ * Returns an immutable map of the expressions of this - corresponding to TensorFlow nodes
+ * which are not Placeholders or Variables (which instead become respectively arguments and constants).
+ * Note that only nodes recursively referenced by a placeholder are added.
+ */
+ public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); }
/** Returns an immutable list, in natural sort order of the warnings generated while importing this */
public List<String> warnings() {
return warnings.stream().sorted().collect(Collectors.toList());
}
+ /** Returns an immutable map of the signatures of this */
+ public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); }
+
/**
* A signature is a set of named inputs and outputs, where the inputs maps to argument ("placeholder") names+types,
- * and outputs maps to ranking expressions stemming from conversion of TensorFlow nodes and the inputs make up the
- * context which is needed to evaluate the expression.
+ * and outputs maps to expressions nodes.
*/
public class Signature {
private final String name;
private final Map<String, String> inputs = new HashMap<>();
- private final Map<String, RankingExpression> outputs = new HashMap<>();
+ private final Map<String, String> outputs = new HashMap<>();
Signature(String name) {
this.name = name;
}
void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); }
- void output(String name, RankingExpression expression) { outputs.put(name, expression); }
+ void output(String name, String expressionName) { outputs.put(name, expressionName); }
/** Returns the result this is part of */
ImportResult owner() { return ImportResult.this; }
@@ -78,10 +86,13 @@ public class ImportResult {
public Map<String, String> inputs() { return Collections.unmodifiableMap(inputs); }
/** Returns owner().arguments().get(inputs.get(name)), e.g the type of the argument this input references */
- public TensorType inputType(String inputName) { return owner().arguments().get(inputs.get(inputName)); }
+ public TensorType inputArgument(String inputName) { return owner().arguments().get(inputs.get(inputName)); }
+
+ /** Returns an immutable list of the expression names of this */
+ public Map<String, String> outputs() { return Collections.unmodifiableMap(outputs); }
- /** Returns an immutable list of the expressions of this */
- public Map<String, RankingExpression> outputs() { return Collections.unmodifiableMap(outputs); }
+ /** Returns owner().expressions().get(outputs.get(outputName)), e.g the expression this output references */
+ public RankingExpression outputExpression(String outputName) { return owner().expressions().get(outputs.get(outputName)); }
@Override
public String toString() { return "signature '" + name + "'"; }
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
index 0f38ee0b0a4..cf72a327ba0 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -38,27 +38,20 @@ public class TensorFlowImporter {
*/
public ImportResult importModel(String modelDir) {
try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) {
- return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), model);
+ return importModel(model);
}
- catch (IOException e) {
- throw new IllegalArgumentException("Could not read TensorFlow model from directory '" + modelDir + "'", e);
+ catch (IllegalArgumentException e) {
+ throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e);
}
}
- /** Imports a specific node as an putput given the name of that node. Useful for testing */
- public ImportResult importNode(String modelDir, String signatureName, String nodeName) {
- try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) {
- MetaGraphDef graph = MetaGraphDef.parseFrom(model.metaGraphDef());
- SignatureDef signatureDef = graph.getSignatureDefMap().get(signatureName);
- ImportResult result = new ImportResult();
- ImportResult.Signature signature = result.signature(signatureName);
- importInputs(signatureDef.getInputsMap(), signature);
- signature.output(nodeName,
- new RankingExpression(nodeName, importNode(nodeName, graph.getGraphDef(), model, result)));
- return result;
+ /** Imports a TensorFlow model */
+ public ImportResult importModel(SavedModelBundle model) {
+ try {
+ return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), model);
}
catch (IOException e) {
- throw new IllegalArgumentException("Could not read TensorFlow model from directory '" + modelDir + "'", e);
+ throw new IllegalArgumentException("Could not import TensorFlow model '" + model + "'", e);
}
}
@@ -72,7 +65,7 @@ public class TensorFlowImporter {
String outputName = output.getKey();
try {
ExpressionNode node = importOutput(output.getValue(), graph.getGraphDef(), model, result);
- signature.output(outputName, new RankingExpression(outputName, node));
+ signature.output(outputName, nameOf(output.getValue().getName()));
}
catch (IllegalArgumentException e) {
result.warn("Skipping output '" + outputName + "' of " + signature +
@@ -111,12 +104,15 @@ public class TensorFlowImporter {
private ExpressionNode importNode(String nodeName, GraphDef graph, SavedModelBundle model, ImportResult result) {
TensorFunction function = importNode(getNode(nodeName, graph), graph, model, result).function();
- return new TensorFunctionNode(function); // wrap top level (only) as an expression
+ result.expression(nodeName, new RankingExpression(nodeName, new TensorFunctionNode(function)));
+ return new TensorFunctionNode(function); // wrap top level (only) as an expression // TODO: waht to return
}
/** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */
private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {
- return tensorFunctionOf(tfNode, graph, model, result);
+ TypedTensorFunction function = tensorFunctionOf(tfNode, graph, model, result);
+ result.expression(tfNode.getName(), new RankingExpression(tfNode.getName(), new TensorFunctionNode(function.function())));
+ return function; // TODO: waht to return
}
private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java
index 53989af4460..0370fc7fc94 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java
@@ -27,7 +27,8 @@ public class Mnist_SoftmaxTestCase {
@Test
public void testImporting() {
String modelDir = "src/test/files/integration/tensorflow/mnist_softmax/saved";
- ImportResult result = new TensorFlowImporter().importModel(modelDir);
+ SavedModelBundle model = SavedModelBundle.load(modelDir, "serve");
+ ImportResult result = new TensorFlowImporter().importModel(model);
// Check logged messages
result.warnings().forEach(System.err::println);
@@ -55,15 +56,15 @@ public class Mnist_SoftmaxTestCase {
// ... signature inputs
assertEquals(1, signature.inputs().size());
- TensorType argument0 = signature.inputType("x");
+ TensorType argument0 = signature.inputArgument("x");
assertNotNull(argument0);
assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), argument0);
// ... signature outputs
assertEquals(1, signature.outputs().size());
- RankingExpression output = signature.outputs().get("y");
+ RankingExpression output = signature.outputExpression("y");
assertNotNull(output);
- assertEquals("y", output.getName());
+ assertEquals("add", output.getName());
assertEquals("" +
"join(rename(matmul(Placeholder, rename(constant(Variable), (d0, d1), (d1, d3)), d1), d3, d1), " +
"rename(constant(Variable_1), d0, d1), " +
@@ -71,28 +72,22 @@ public class Mnist_SoftmaxTestCase {
toNonPrimitiveString(output));
// Test execution
- // TODO: Pass imported result instead of re-importing
- String signatureName = "serving_default";
- assertEqualResult(modelDir, signatureName, "Variable/read");
- assertEqualResult(modelDir, signatureName, "Variable_1/read");
- // TODO: Assert that argument fed is as expected assertEqualResult(modelDir, signatureName, "Placeholder");
- assertEqualResult(modelDir, signatureName, "MatMul");
- assertEqualResult(modelDir, signatureName, "add");
+ assertEqualResult(model, result, "Variable/read");
+ assertEqualResult(model, result, "Variable_1/read");
+ assertEqualResult(model, result, "MatMul");
+ assertEqualResult(model, result, "add");
}
- private void assertEqualResult(String modelDir, String signatureName, String operationName) {
- ImportResult result = new TensorFlowImporter().importNode(modelDir, signatureName, operationName);
-
- Tensor tfResult = tensorFlowExecute(modelDir, operationName);
+ private void assertEqualResult(SavedModelBundle model, ImportResult result, String operationName) {
+ Tensor tfResult = tensorFlowExecute(model, operationName);
Context context = contextFrom(result);
Tensor placeholder = placeholderArgument();
context.put("Placeholder", new TensorValue(placeholder));
- Tensor vespaResult = result.signatures().get(signatureName).outputs().get(operationName).evaluate(context).asTensor();
+ Tensor vespaResult = result.expressions().get(operationName).evaluate(context).asTensor();
assertEquals("Operation '" + operationName + "' produces equal results", vespaResult, tfResult);
}
- private Tensor tensorFlowExecute(String modelDir, String operationName) {
- SavedModelBundle model = SavedModelBundle.load(modelDir, "serve");
+ private Tensor tensorFlowExecute(SavedModelBundle model, String operationName) {
Session.Runner runner = model.session().runner();
org.tensorflow.Tensor<?> placeholder = org.tensorflow.Tensor.create(new long[]{ 1, 784 }, FloatBuffer.allocate(784));
runner.feed("Placeholder", placeholder);