diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-12-20 15:25:20 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-12-20 15:25:20 +0100 |
commit | f4e87644d6ee9801786aa3cc9101bce162cfe55a (patch) | |
tree | f30bf95eb7231ff6b1ab7f8be7b6385457c1e257 /searchlib | |
parent | 5dce0c978c36936c7372e32d1a05f05c0b61386e (diff) |
Simplify
Diffstat (limited to 'searchlib')
2 files changed, 21 insertions, 25 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java index 5e2d7530200..bac141644c6 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java @@ -86,9 +86,9 @@ class OperationMapper { return new TypedTensorFunction(resultType, function); } - TypedTensorFunction placeholder(NodeDef tfNode, ImportResult.Signature signature) { + TypedTensorFunction placeholder(NodeDef tfNode, ImportResult result) { String name = tfNode.getName(); - TensorType type = signature.owner().arguments().get(name); + TensorType type = result.arguments().get(name); if (type == null) throw new IllegalArgumentException("A 'placeholder' node is referencing placeholder '" + name + "', but there is no such placeholder"); @@ -96,7 +96,7 @@ class OperationMapper { return new TypedTensorFunction(type, new VariableTensor(name)); } - TypedTensorFunction identity(NodeDef tfNode, SavedModelBundle model, ImportResult.Signature signature) { + TypedTensorFunction identity(NodeDef tfNode, SavedModelBundle model, ImportResult result) { if ( ! tfNode.getName().endsWith("/read")) throw new IllegalArgumentException("Encountered identity node " + tfNode.getName() + ", but identify " + "nodes are only supported when reading variables"); @@ -114,7 +114,7 @@ class OperationMapper { throw new IllegalStateException("Expected 1 tensor from reading Variable " + name + ", but got " + importedTensors.size()); Tensor constant = tensorConverter.toVespaTensor(importedTensors.get(0)); - signature.owner().constant(name, constant); + result.constant(name, constant); return new TypedTensorFunction(constant.type(), new TensorFunctionNode.TensorFunctionExpressionNode(new ReferenceNode("constant(" + name + ")"))); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index 66bfe9fcfb9..0f38ee0b0a4 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -54,7 +54,7 @@ public class TensorFlowImporter { ImportResult.Signature signature = result.signature(signatureName); importInputs(signatureDef.getInputsMap(), signature); signature.output(nodeName, - new RankingExpression(nodeName, importNode(nodeName, graph.getGraphDef(), model, signature))); + new RankingExpression(nodeName, importNode(nodeName, graph.getGraphDef(), model, result))); return result; } catch (IOException e) { @@ -71,7 +71,7 @@ public class TensorFlowImporter { for (Map.Entry<String, TensorInfo> output : signatureEntry.getValue().getOutputsMap().entrySet()) { String outputName = output.getKey(); try { - ExpressionNode node = importOutput(output.getValue(), graph.getGraphDef(), model, signature); + ExpressionNode node = importOutput(output.getValue(), graph.getGraphDef(), model, result); signature.output(outputName, new RankingExpression(outputName, node)); } catch (IllegalArgumentException e) { @@ -105,42 +105,38 @@ public class TensorFlowImporter { return b.build(); } - private ExpressionNode importOutput(TensorInfo output, GraphDef graph, SavedModelBundle model, - ImportResult.Signature signature) { - return importNode(nameOf(output.getName()), graph, model, signature); + private ExpressionNode importOutput(TensorInfo output, GraphDef graph, SavedModelBundle model, ImportResult result) { + return importNode(nameOf(output.getName()), graph, model, result); } - private ExpressionNode importNode(String nodeName, GraphDef graph, SavedModelBundle model, - ImportResult.Signature signature) { - TensorFunction function = importNode(getNode(nodeName, graph), graph, model, signature).function(); + private ExpressionNode importNode(String nodeName, GraphDef graph, SavedModelBundle model, ImportResult result) { + TensorFunction function = importNode(getNode(nodeName, graph), graph, model, result).function(); return new TensorFunctionNode(function); // wrap top level (only) as an expression } /** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */ - private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, - ImportResult.Signature signature) { - return tensorFunctionOf(tfNode, graph, model, signature); + private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) { + return tensorFunctionOf(tfNode, graph, model, result); } - private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, - ImportResult.Signature signature) { + private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) { // Import arguments lazily below, as some nodes have arguments unused arguments leading to unsupported ops // TODO: Implement mapping of more functions from https://www.tensorflow.org/api_docs/python/ switch (tfNode.getOp().toLowerCase()) { - case "add" : case "add_n" : return operationMapper.join(importArguments(tfNode, graph, model, signature), ScalarFunctions.add()); - case "acos" : return operationMapper.map(importArguments(tfNode, graph, model, signature), ScalarFunctions.acos()); - case "placeholder" : return operationMapper.placeholder(tfNode, signature); - case "identity" : return operationMapper.identity(tfNode, model, signature); - case "matmul" : return operationMapper.matmul(importArguments(tfNode, graph, model, signature)); - case "softmax" : return operationMapper.softmax(importArguments(tfNode, graph, model, signature)); + case "add" : case "add_n" : return operationMapper.join(importArguments(tfNode, graph, model, result), ScalarFunctions.add()); + case "acos" : return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.acos()); + case "placeholder" : return operationMapper.placeholder(tfNode, result); + case "identity" : return operationMapper.identity(tfNode, model, result); + case "matmul" : return operationMapper.matmul(importArguments(tfNode, graph, model, result)); + case "softmax" : return operationMapper.softmax(importArguments(tfNode, graph, model, result)); default : throw new IllegalArgumentException("Conversion of TensorFlow operation '" + tfNode.getOp() + "' is not supported"); } } private List<TypedTensorFunction> importArguments(NodeDef tfNode, GraphDef graph, SavedModelBundle model, - ImportResult.Signature signature) { + ImportResult result) { return tfNode.getInputList().stream() - .map(argNode -> importNode(getNode(nameOf(argNode), graph), graph, model, signature)) + .map(argNode -> importNode(getNode(nameOf(argNode), graph), graph, model, result)) .collect(Collectors.toList()); } |