diff options
Diffstat (limited to 'searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java')
-rw-r--r-- | searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java | 71 |
1 files changed, 43 insertions, 28 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index 33523244129..66bfe9fcfb9 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -45,13 +45,16 @@ public class TensorFlowImporter { } } - public ImportResult importNode(String modelDir, String inputSignatureName, String nodeName) { + /** Imports a specific node as an putput given the name of that node. Useful for testing */ + public ImportResult importNode(String modelDir, String signatureName, String nodeName) { try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) { MetaGraphDef graph = MetaGraphDef.parseFrom(model.metaGraphDef()); - SignatureDef signature = graph.getSignatureDefMap().get(inputSignatureName); + SignatureDef signatureDef = graph.getSignatureDefMap().get(signatureName); ImportResult result = new ImportResult(); - importInputs(signature.getInputsMap(), result); - result.add(new RankingExpression(nodeName, importNode(nodeName, graph.getGraphDef(), model, result))); + ImportResult.Signature signature = result.signature(signatureName); + importInputs(signatureDef.getInputsMap(), signature); + signature.output(nodeName, + new RankingExpression(nodeName, importNode(nodeName, graph.getGraphDef(), model, signature))); return result; } catch (IOException e) { @@ -62,25 +65,32 @@ public class TensorFlowImporter { private ImportResult importGraph(MetaGraphDef graph, SavedModelBundle model) { ImportResult result = new ImportResult(); for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) { - importInputs(signatureEntry.getValue().getInputsMap(), result); + ImportResult.Signature signature = result.signature(signatureEntry.getKey()); // Prefer key over "methodName" + + importInputs(signatureEntry.getValue().getInputsMap(), signature); for (Map.Entry<String, TensorInfo> output : signatureEntry.getValue().getOutputsMap().entrySet()) { + String outputName = output.getKey(); try { - ExpressionNode node = importOutput(output.getValue(), graph.getGraphDef(), model, result); - result.add(new RankingExpression(output.getKey(), node)); + ExpressionNode node = importOutput(output.getValue(), graph.getGraphDef(), model, signature); + signature.output(outputName, new RankingExpression(outputName, node)); } catch (IllegalArgumentException e) { - result.warn("Skipping output '" + output.getValue().getName() + "' of signature '" + - signatureEntry.getValue().getMethodName() + - "': " + Exceptions.toMessageString(e)); + result.warn("Skipping output '" + outputName + "' of " + signature + + ": " + Exceptions.toMessageString(e)); } } } return result; } - private void importInputs(Map<String, TensorInfo> inputInfoMap, ImportResult result) { - inputInfoMap.forEach((key, value) -> result.set(nameOf(value.getName()), - importTensorType(value.getTensorShape()))); + private void importInputs(Map<String, TensorInfo> inputInfoMap, ImportResult.Signature signature) { + inputInfoMap.forEach((key, value) -> { + String argumentName = nameOf(value.getName()); + TensorType argumentType = importTensorType(value.getTensorShape()); + // Arguments are (Placeholder) nodes, so not local to the signature: + signature.owner().argument(argumentName, argumentType); + signature.input(key, argumentName); + }); } private TensorType importTensorType(TensorShapeProto tensorShape) { @@ -95,37 +105,42 @@ public class TensorFlowImporter { return b.build(); } - private ExpressionNode importOutput(TensorInfo output, GraphDef graph, SavedModelBundle model, ImportResult result) { - return importNode(nameOf(output.getName()), graph, model, result); + private ExpressionNode importOutput(TensorInfo output, GraphDef graph, SavedModelBundle model, + ImportResult.Signature signature) { + return importNode(nameOf(output.getName()), graph, model, signature); } - private ExpressionNode importNode(String nodeName, GraphDef graph, SavedModelBundle model, ImportResult result) { - TensorFunction function = importNode(getNode(nodeName, graph), graph, model, result).function(); + private ExpressionNode importNode(String nodeName, GraphDef graph, SavedModelBundle model, + ImportResult.Signature signature) { + TensorFunction function = importNode(getNode(nodeName, graph), graph, model, signature).function(); return new TensorFunctionNode(function); // wrap top level (only) as an expression } /** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */ - private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) { - return tensorFunctionOf(tfNode, graph, model, result); + private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, + ImportResult.Signature signature) { + return tensorFunctionOf(tfNode, graph, model, signature); } - private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) { + private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, + ImportResult.Signature signature) { // Import arguments lazily below, as some nodes have arguments unused arguments leading to unsupported ops // TODO: Implement mapping of more functions from https://www.tensorflow.org/api_docs/python/ switch (tfNode.getOp().toLowerCase()) { - case "add" : case "add_n" : return operationMapper.join(importArguments(tfNode, graph, model, result), ScalarFunctions.add()); - case "acos" : return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.acos()); - case "placeholder" : return operationMapper.placeholder(tfNode, result); - case "identity" : return operationMapper.identity(tfNode, model, result); - case "matmul" : return operationMapper.matmul(importArguments(tfNode, graph, model, result)); - case "softmax" : return operationMapper.softmax(importArguments(tfNode, graph, model, result)); + case "add" : case "add_n" : return operationMapper.join(importArguments(tfNode, graph, model, signature), ScalarFunctions.add()); + case "acos" : return operationMapper.map(importArguments(tfNode, graph, model, signature), ScalarFunctions.acos()); + case "placeholder" : return operationMapper.placeholder(tfNode, signature); + case "identity" : return operationMapper.identity(tfNode, model, signature); + case "matmul" : return operationMapper.matmul(importArguments(tfNode, graph, model, signature)); + case "softmax" : return operationMapper.softmax(importArguments(tfNode, graph, model, signature)); default : throw new IllegalArgumentException("Conversion of TensorFlow operation '" + tfNode.getOp() + "' is not supported"); } } - private List<TypedTensorFunction> importArguments(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) { + private List<TypedTensorFunction> importArguments(NodeDef tfNode, GraphDef graph, SavedModelBundle model, + ImportResult.Signature signature) { return tfNode.getInputList().stream() - .map(argNode -> importNode(getNode(nameOf(argNode), graph), graph, model, result)) + .map(argNode -> importNode(getNode(nameOf(argNode), graph), graph, model, signature)) .collect(Collectors.toList()); } |