aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
diff options
context:
space:
mode:
Diffstat (limited to 'searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java71
1 files changed, 43 insertions, 28 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
index 33523244129..66bfe9fcfb9 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -45,13 +45,16 @@ public class TensorFlowImporter {
}
}
- public ImportResult importNode(String modelDir, String inputSignatureName, String nodeName) {
+ /** Imports a specific node as an putput given the name of that node. Useful for testing */
+ public ImportResult importNode(String modelDir, String signatureName, String nodeName) {
try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) {
MetaGraphDef graph = MetaGraphDef.parseFrom(model.metaGraphDef());
- SignatureDef signature = graph.getSignatureDefMap().get(inputSignatureName);
+ SignatureDef signatureDef = graph.getSignatureDefMap().get(signatureName);
ImportResult result = new ImportResult();
- importInputs(signature.getInputsMap(), result);
- result.add(new RankingExpression(nodeName, importNode(nodeName, graph.getGraphDef(), model, result)));
+ ImportResult.Signature signature = result.signature(signatureName);
+ importInputs(signatureDef.getInputsMap(), signature);
+ signature.output(nodeName,
+ new RankingExpression(nodeName, importNode(nodeName, graph.getGraphDef(), model, signature)));
return result;
}
catch (IOException e) {
@@ -62,25 +65,32 @@ public class TensorFlowImporter {
private ImportResult importGraph(MetaGraphDef graph, SavedModelBundle model) {
ImportResult result = new ImportResult();
for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) {
- importInputs(signatureEntry.getValue().getInputsMap(), result);
+ ImportResult.Signature signature = result.signature(signatureEntry.getKey()); // Prefer key over "methodName"
+
+ importInputs(signatureEntry.getValue().getInputsMap(), signature);
for (Map.Entry<String, TensorInfo> output : signatureEntry.getValue().getOutputsMap().entrySet()) {
+ String outputName = output.getKey();
try {
- ExpressionNode node = importOutput(output.getValue(), graph.getGraphDef(), model, result);
- result.add(new RankingExpression(output.getKey(), node));
+ ExpressionNode node = importOutput(output.getValue(), graph.getGraphDef(), model, signature);
+ signature.output(outputName, new RankingExpression(outputName, node));
}
catch (IllegalArgumentException e) {
- result.warn("Skipping output '" + output.getValue().getName() + "' of signature '" +
- signatureEntry.getValue().getMethodName() +
- "': " + Exceptions.toMessageString(e));
+ result.warn("Skipping output '" + outputName + "' of " + signature +
+ ": " + Exceptions.toMessageString(e));
}
}
}
return result;
}
- private void importInputs(Map<String, TensorInfo> inputInfoMap, ImportResult result) {
- inputInfoMap.forEach((key, value) -> result.set(nameOf(value.getName()),
- importTensorType(value.getTensorShape())));
+ private void importInputs(Map<String, TensorInfo> inputInfoMap, ImportResult.Signature signature) {
+ inputInfoMap.forEach((key, value) -> {
+ String argumentName = nameOf(value.getName());
+ TensorType argumentType = importTensorType(value.getTensorShape());
+ // Arguments are (Placeholder) nodes, so not local to the signature:
+ signature.owner().argument(argumentName, argumentType);
+ signature.input(key, argumentName);
+ });
}
private TensorType importTensorType(TensorShapeProto tensorShape) {
@@ -95,37 +105,42 @@ public class TensorFlowImporter {
return b.build();
}
- private ExpressionNode importOutput(TensorInfo output, GraphDef graph, SavedModelBundle model, ImportResult result) {
- return importNode(nameOf(output.getName()), graph, model, result);
+ private ExpressionNode importOutput(TensorInfo output, GraphDef graph, SavedModelBundle model,
+ ImportResult.Signature signature) {
+ return importNode(nameOf(output.getName()), graph, model, signature);
}
- private ExpressionNode importNode(String nodeName, GraphDef graph, SavedModelBundle model, ImportResult result) {
- TensorFunction function = importNode(getNode(nodeName, graph), graph, model, result).function();
+ private ExpressionNode importNode(String nodeName, GraphDef graph, SavedModelBundle model,
+ ImportResult.Signature signature) {
+ TensorFunction function = importNode(getNode(nodeName, graph), graph, model, signature).function();
return new TensorFunctionNode(function); // wrap top level (only) as an expression
}
/** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */
- private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {
- return tensorFunctionOf(tfNode, graph, model, result);
+ private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model,
+ ImportResult.Signature signature) {
+ return tensorFunctionOf(tfNode, graph, model, signature);
}
- private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {
+ private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model,
+ ImportResult.Signature signature) {
// Import arguments lazily below, as some nodes have arguments unused arguments leading to unsupported ops
// TODO: Implement mapping of more functions from https://www.tensorflow.org/api_docs/python/
switch (tfNode.getOp().toLowerCase()) {
- case "add" : case "add_n" : return operationMapper.join(importArguments(tfNode, graph, model, result), ScalarFunctions.add());
- case "acos" : return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.acos());
- case "placeholder" : return operationMapper.placeholder(tfNode, result);
- case "identity" : return operationMapper.identity(tfNode, model, result);
- case "matmul" : return operationMapper.matmul(importArguments(tfNode, graph, model, result));
- case "softmax" : return operationMapper.softmax(importArguments(tfNode, graph, model, result));
+ case "add" : case "add_n" : return operationMapper.join(importArguments(tfNode, graph, model, signature), ScalarFunctions.add());
+ case "acos" : return operationMapper.map(importArguments(tfNode, graph, model, signature), ScalarFunctions.acos());
+ case "placeholder" : return operationMapper.placeholder(tfNode, signature);
+ case "identity" : return operationMapper.identity(tfNode, model, signature);
+ case "matmul" : return operationMapper.matmul(importArguments(tfNode, graph, model, signature));
+ case "softmax" : return operationMapper.softmax(importArguments(tfNode, graph, model, signature));
default : throw new IllegalArgumentException("Conversion of TensorFlow operation '" + tfNode.getOp() + "' is not supported");
}
}
- private List<TypedTensorFunction> importArguments(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {
+ private List<TypedTensorFunction> importArguments(NodeDef tfNode, GraphDef graph, SavedModelBundle model,
+ ImportResult.Signature signature) {
return tfNode.getInputList().stream()
- .map(argNode -> importNode(getNode(nameOf(argNode), graph), graph, model, result))
+ .map(argNode -> importNode(getNode(nameOf(argNode), graph), graph, model, signature))
.collect(Collectors.toList());
}