diff options
Diffstat (limited to 'searchlib')
3 files changed, 47 insertions, 45 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java index b3c1708a0f4..947e6d7a5e1 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java @@ -25,10 +25,12 @@ public class ImportResult { private final Map<String, Signature> signatures = new HashMap<>(); private final Map<String, TensorType> arguments = new HashMap<>(); private final Map<String, Tensor> constants = new HashMap<>(); + private final Map<String, RankingExpression> expressions = new HashMap<>(); private final List<String> warnings = new ArrayList<>(); void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); } void constant(String name, Tensor constant) { constants.put(name, constant); } + void expression(String name, RankingExpression expression) { expressions.put(name, expression); } void warn(String warning) { warnings.add(warning); } /** Returns the given signature. If it does not already exist it is added to this. */ @@ -42,31 +44,37 @@ public class ImportResult { /** Returns an immutable map of the constants of this */ public Map<String, Tensor> constants() { return Collections.unmodifiableMap(constants); } - /** Returns an immutable map of the signatures of this */ - public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); } + /** + * Returns an immutable map of the expressions of this - corresponding to TensorFlow nodes + * which are not Placeholders or Variables (which instead become respectively arguments and constants). + * Note that only nodes recursively referenced by a placeholder are added. + */ + public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); } /** Returns an immutable list, in natural sort order of the warnings generated while importing this */ public List<String> warnings() { return warnings.stream().sorted().collect(Collectors.toList()); } + /** Returns an immutable map of the signatures of this */ + public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); } + /** * A signature is a set of named inputs and outputs, where the inputs maps to argument ("placeholder") names+types, - * and outputs maps to ranking expressions stemming from conversion of TensorFlow nodes and the inputs make up the - * context which is needed to evaluate the expression. + * and outputs maps to expressions nodes. */ public class Signature { private final String name; private final Map<String, String> inputs = new HashMap<>(); - private final Map<String, RankingExpression> outputs = new HashMap<>(); + private final Map<String, String> outputs = new HashMap<>(); Signature(String name) { this.name = name; } void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); } - void output(String name, RankingExpression expression) { outputs.put(name, expression); } + void output(String name, String expressionName) { outputs.put(name, expressionName); } /** Returns the result this is part of */ ImportResult owner() { return ImportResult.this; } @@ -78,10 +86,13 @@ public class ImportResult { public Map<String, String> inputs() { return Collections.unmodifiableMap(inputs); } /** Returns owner().arguments().get(inputs.get(name)), e.g the type of the argument this input references */ - public TensorType inputType(String inputName) { return owner().arguments().get(inputs.get(inputName)); } + public TensorType inputArgument(String inputName) { return owner().arguments().get(inputs.get(inputName)); } + + /** Returns an immutable list of the expression names of this */ + public Map<String, String> outputs() { return Collections.unmodifiableMap(outputs); } - /** Returns an immutable list of the expressions of this */ - public Map<String, RankingExpression> outputs() { return Collections.unmodifiableMap(outputs); } + /** Returns owner().expressions().get(outputs.get(outputName)), e.g the expression this output references */ + public RankingExpression outputExpression(String outputName) { return owner().expressions().get(outputs.get(outputName)); } @Override public String toString() { return "signature '" + name + "'"; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index 0f38ee0b0a4..cf72a327ba0 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -38,27 +38,20 @@ public class TensorFlowImporter { */ public ImportResult importModel(String modelDir) { try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) { - return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), model); + return importModel(model); } - catch (IOException e) { - throw new IllegalArgumentException("Could not read TensorFlow model from directory '" + modelDir + "'", e); + catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e); } } - /** Imports a specific node as an putput given the name of that node. Useful for testing */ - public ImportResult importNode(String modelDir, String signatureName, String nodeName) { - try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) { - MetaGraphDef graph = MetaGraphDef.parseFrom(model.metaGraphDef()); - SignatureDef signatureDef = graph.getSignatureDefMap().get(signatureName); - ImportResult result = new ImportResult(); - ImportResult.Signature signature = result.signature(signatureName); - importInputs(signatureDef.getInputsMap(), signature); - signature.output(nodeName, - new RankingExpression(nodeName, importNode(nodeName, graph.getGraphDef(), model, result))); - return result; + /** Imports a TensorFlow model */ + public ImportResult importModel(SavedModelBundle model) { + try { + return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), model); } catch (IOException e) { - throw new IllegalArgumentException("Could not read TensorFlow model from directory '" + modelDir + "'", e); + throw new IllegalArgumentException("Could not import TensorFlow model '" + model + "'", e); } } @@ -72,7 +65,7 @@ public class TensorFlowImporter { String outputName = output.getKey(); try { ExpressionNode node = importOutput(output.getValue(), graph.getGraphDef(), model, result); - signature.output(outputName, new RankingExpression(outputName, node)); + signature.output(outputName, nameOf(output.getValue().getName())); } catch (IllegalArgumentException e) { result.warn("Skipping output '" + outputName + "' of " + signature + @@ -111,12 +104,15 @@ public class TensorFlowImporter { private ExpressionNode importNode(String nodeName, GraphDef graph, SavedModelBundle model, ImportResult result) { TensorFunction function = importNode(getNode(nodeName, graph), graph, model, result).function(); - return new TensorFunctionNode(function); // wrap top level (only) as an expression + result.expression(nodeName, new RankingExpression(nodeName, new TensorFunctionNode(function))); + return new TensorFunctionNode(function); // wrap top level (only) as an expression // TODO: waht to return } /** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */ private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) { - return tensorFunctionOf(tfNode, graph, model, result); + TypedTensorFunction function = tensorFunctionOf(tfNode, graph, model, result); + result.expression(tfNode.getName(), new RankingExpression(tfNode.getName(), new TensorFunctionNode(function.function()))); + return function; // TODO: waht to return } private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) { diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java index 53989af4460..0370fc7fc94 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java @@ -27,7 +27,8 @@ public class Mnist_SoftmaxTestCase { @Test public void testImporting() { String modelDir = "src/test/files/integration/tensorflow/mnist_softmax/saved"; - ImportResult result = new TensorFlowImporter().importModel(modelDir); + SavedModelBundle model = SavedModelBundle.load(modelDir, "serve"); + ImportResult result = new TensorFlowImporter().importModel(model); // Check logged messages result.warnings().forEach(System.err::println); @@ -55,15 +56,15 @@ public class Mnist_SoftmaxTestCase { // ... signature inputs assertEquals(1, signature.inputs().size()); - TensorType argument0 = signature.inputType("x"); + TensorType argument0 = signature.inputArgument("x"); assertNotNull(argument0); assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), argument0); // ... signature outputs assertEquals(1, signature.outputs().size()); - RankingExpression output = signature.outputs().get("y"); + RankingExpression output = signature.outputExpression("y"); assertNotNull(output); - assertEquals("y", output.getName()); + assertEquals("add", output.getName()); assertEquals("" + "join(rename(matmul(Placeholder, rename(constant(Variable), (d0, d1), (d1, d3)), d1), d3, d1), " + "rename(constant(Variable_1), d0, d1), " + @@ -71,28 +72,22 @@ public class Mnist_SoftmaxTestCase { toNonPrimitiveString(output)); // Test execution - // TODO: Pass imported result instead of re-importing - String signatureName = "serving_default"; - assertEqualResult(modelDir, signatureName, "Variable/read"); - assertEqualResult(modelDir, signatureName, "Variable_1/read"); - // TODO: Assert that argument fed is as expected assertEqualResult(modelDir, signatureName, "Placeholder"); - assertEqualResult(modelDir, signatureName, "MatMul"); - assertEqualResult(modelDir, signatureName, "add"); + assertEqualResult(model, result, "Variable/read"); + assertEqualResult(model, result, "Variable_1/read"); + assertEqualResult(model, result, "MatMul"); + assertEqualResult(model, result, "add"); } - private void assertEqualResult(String modelDir, String signatureName, String operationName) { - ImportResult result = new TensorFlowImporter().importNode(modelDir, signatureName, operationName); - - Tensor tfResult = tensorFlowExecute(modelDir, operationName); + private void assertEqualResult(SavedModelBundle model, ImportResult result, String operationName) { + Tensor tfResult = tensorFlowExecute(model, operationName); Context context = contextFrom(result); Tensor placeholder = placeholderArgument(); context.put("Placeholder", new TensorValue(placeholder)); - Tensor vespaResult = result.signatures().get(signatureName).outputs().get(operationName).evaluate(context).asTensor(); + Tensor vespaResult = result.expressions().get(operationName).evaluate(context).asTensor(); assertEquals("Operation '" + operationName + "' produces equal results", vespaResult, tfResult); } - private Tensor tensorFlowExecute(String modelDir, String operationName) { - SavedModelBundle model = SavedModelBundle.load(modelDir, "serve"); + private Tensor tensorFlowExecute(SavedModelBundle model, String operationName) { Session.Runner runner = model.session().runner(); org.tensorflow.Tensor<?> placeholder = org.tensorflow.Tensor.create(new long[]{ 1, 784 }, FloatBuffer.allocate(784)); runner.feed("Placeholder", placeholder); |