summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2017-12-20 15:25:20 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2017-12-20 15:25:20 +0100
commitf4e87644d6ee9801786aa3cc9101bce162cfe55a (patch)
treef30bf95eb7231ff6b1ab7f8be7b6385457c1e257 /searchlib
parent5dce0c978c36936c7372e32d1a05f05c0b61386e (diff)
Simplify
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java8
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java38
2 files changed, 21 insertions, 25 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
index 5e2d7530200..bac141644c6 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
@@ -86,9 +86,9 @@ class OperationMapper {
return new TypedTensorFunction(resultType, function);
}
- TypedTensorFunction placeholder(NodeDef tfNode, ImportResult.Signature signature) {
+ TypedTensorFunction placeholder(NodeDef tfNode, ImportResult result) {
String name = tfNode.getName();
- TensorType type = signature.owner().arguments().get(name);
+ TensorType type = result.arguments().get(name);
if (type == null)
throw new IllegalArgumentException("A 'placeholder' node is referencing placeholder '" + name +
"', but there is no such placeholder");
@@ -96,7 +96,7 @@ class OperationMapper {
return new TypedTensorFunction(type, new VariableTensor(name));
}
- TypedTensorFunction identity(NodeDef tfNode, SavedModelBundle model, ImportResult.Signature signature) {
+ TypedTensorFunction identity(NodeDef tfNode, SavedModelBundle model, ImportResult result) {
if ( ! tfNode.getName().endsWith("/read"))
throw new IllegalArgumentException("Encountered identity node " + tfNode.getName() + ", but identify " +
"nodes are only supported when reading variables");
@@ -114,7 +114,7 @@ class OperationMapper {
throw new IllegalStateException("Expected 1 tensor from reading Variable " + name + ", but got " +
importedTensors.size());
Tensor constant = tensorConverter.toVespaTensor(importedTensors.get(0));
- signature.owner().constant(name, constant);
+ result.constant(name, constant);
return new TypedTensorFunction(constant.type(),
new TensorFunctionNode.TensorFunctionExpressionNode(new ReferenceNode("constant(" + name + ")")));
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
index 66bfe9fcfb9..0f38ee0b0a4 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -54,7 +54,7 @@ public class TensorFlowImporter {
ImportResult.Signature signature = result.signature(signatureName);
importInputs(signatureDef.getInputsMap(), signature);
signature.output(nodeName,
- new RankingExpression(nodeName, importNode(nodeName, graph.getGraphDef(), model, signature)));
+ new RankingExpression(nodeName, importNode(nodeName, graph.getGraphDef(), model, result)));
return result;
}
catch (IOException e) {
@@ -71,7 +71,7 @@ public class TensorFlowImporter {
for (Map.Entry<String, TensorInfo> output : signatureEntry.getValue().getOutputsMap().entrySet()) {
String outputName = output.getKey();
try {
- ExpressionNode node = importOutput(output.getValue(), graph.getGraphDef(), model, signature);
+ ExpressionNode node = importOutput(output.getValue(), graph.getGraphDef(), model, result);
signature.output(outputName, new RankingExpression(outputName, node));
}
catch (IllegalArgumentException e) {
@@ -105,42 +105,38 @@ public class TensorFlowImporter {
return b.build();
}
- private ExpressionNode importOutput(TensorInfo output, GraphDef graph, SavedModelBundle model,
- ImportResult.Signature signature) {
- return importNode(nameOf(output.getName()), graph, model, signature);
+ private ExpressionNode importOutput(TensorInfo output, GraphDef graph, SavedModelBundle model, ImportResult result) {
+ return importNode(nameOf(output.getName()), graph, model, result);
}
- private ExpressionNode importNode(String nodeName, GraphDef graph, SavedModelBundle model,
- ImportResult.Signature signature) {
- TensorFunction function = importNode(getNode(nodeName, graph), graph, model, signature).function();
+ private ExpressionNode importNode(String nodeName, GraphDef graph, SavedModelBundle model, ImportResult result) {
+ TensorFunction function = importNode(getNode(nodeName, graph), graph, model, result).function();
return new TensorFunctionNode(function); // wrap top level (only) as an expression
}
/** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */
- private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model,
- ImportResult.Signature signature) {
- return tensorFunctionOf(tfNode, graph, model, signature);
+ private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {
+ return tensorFunctionOf(tfNode, graph, model, result);
}
- private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model,
- ImportResult.Signature signature) {
+ private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {
// Import arguments lazily below, as some nodes have arguments unused arguments leading to unsupported ops
// TODO: Implement mapping of more functions from https://www.tensorflow.org/api_docs/python/
switch (tfNode.getOp().toLowerCase()) {
- case "add" : case "add_n" : return operationMapper.join(importArguments(tfNode, graph, model, signature), ScalarFunctions.add());
- case "acos" : return operationMapper.map(importArguments(tfNode, graph, model, signature), ScalarFunctions.acos());
- case "placeholder" : return operationMapper.placeholder(tfNode, signature);
- case "identity" : return operationMapper.identity(tfNode, model, signature);
- case "matmul" : return operationMapper.matmul(importArguments(tfNode, graph, model, signature));
- case "softmax" : return operationMapper.softmax(importArguments(tfNode, graph, model, signature));
+ case "add" : case "add_n" : return operationMapper.join(importArguments(tfNode, graph, model, result), ScalarFunctions.add());
+ case "acos" : return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.acos());
+ case "placeholder" : return operationMapper.placeholder(tfNode, result);
+ case "identity" : return operationMapper.identity(tfNode, model, result);
+ case "matmul" : return operationMapper.matmul(importArguments(tfNode, graph, model, result));
+ case "softmax" : return operationMapper.softmax(importArguments(tfNode, graph, model, result));
default : throw new IllegalArgumentException("Conversion of TensorFlow operation '" + tfNode.getOp() + "' is not supported");
}
}
private List<TypedTensorFunction> importArguments(NodeDef tfNode, GraphDef graph, SavedModelBundle model,
- ImportResult.Signature signature) {
+ ImportResult result) {
return tfNode.getInputList().stream()
- .map(argNode -> importNode(getNode(nameOf(argNode), graph), graph, model, signature))
+ .map(argNode -> importNode(getNode(nameOf(argNode), graph), graph, model, result))
.collect(Collectors.toList());
}