aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/main/java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2018-01-10 12:50:08 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2018-01-10 12:50:08 +0100
commitd479ea2ed063832297286e60e9ffb2b7f248be59 (patch)
treee1d27ccc652d2be336159fafb37ea94d14cb1d2e /searchlib/src/main/java
parent9ef5fd6f9edf47d48c34cd6a8623ac38daa933f5 (diff)
ImportResult -> TensorFlowModel
Diffstat (limited to 'searchlib/src/main/java')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java18
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java)4
3 files changed, 13 insertions, 13 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
index bac141644c6..d5958c71454 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
@@ -86,7 +86,7 @@ class OperationMapper {
return new TypedTensorFunction(resultType, function);
}
- TypedTensorFunction placeholder(NodeDef tfNode, ImportResult result) {
+ TypedTensorFunction placeholder(NodeDef tfNode, TensorFlowModel result) {
String name = tfNode.getName();
TensorType type = result.arguments().get(name);
if (type == null)
@@ -96,7 +96,7 @@ class OperationMapper {
return new TypedTensorFunction(type, new VariableTensor(name));
}
- TypedTensorFunction identity(NodeDef tfNode, SavedModelBundle model, ImportResult result) {
+ TypedTensorFunction identity(NodeDef tfNode, SavedModelBundle model, TensorFlowModel result) {
if ( ! tfNode.getName().endsWith("/read"))
throw new IllegalArgumentException("Encountered identity node " + tfNode.getName() + ", but identify " +
"nodes are only supported when reading variables");
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
index 69781fa915c..e62ff4c54bf 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -34,7 +34,7 @@ public class TensorFlowImporter {
*
* @param modelDir the directory containing the TensorFlow model files to import
*/
- public ImportResult importModel(String modelDir) {
+ public TensorFlowModel importModel(String modelDir) {
try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) {
return importModel(model);
}
@@ -44,7 +44,7 @@ public class TensorFlowImporter {
}
/** Imports a TensorFlow model */
- public ImportResult importModel(SavedModelBundle model) {
+ public TensorFlowModel importModel(SavedModelBundle model) {
try {
return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), model);
}
@@ -53,10 +53,10 @@ public class TensorFlowImporter {
}
}
- private ImportResult importGraph(MetaGraphDef graph, SavedModelBundle model) {
- ImportResult result = new ImportResult();
+ private TensorFlowModel importGraph(MetaGraphDef graph, SavedModelBundle model) {
+ TensorFlowModel result = new TensorFlowModel();
for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) {
- ImportResult.Signature signature = result.signature(signatureEntry.getKey()); // Prefer key over "methodName"
+ TensorFlowModel.Signature signature = result.signature(signatureEntry.getKey()); // Prefer key over "methodName"
importInputs(signatureEntry.getValue().getInputsMap(), signature);
for (Map.Entry<String, TensorInfo> output : signatureEntry.getValue().getOutputsMap().entrySet()) {
@@ -74,7 +74,7 @@ public class TensorFlowImporter {
return result;
}
- private void importInputs(Map<String, TensorInfo> inputInfoMap, ImportResult.Signature signature) {
+ private void importInputs(Map<String, TensorInfo> inputInfoMap, TensorFlowModel.Signature signature) {
inputInfoMap.forEach((key, value) -> {
String argumentName = nameOf(value.getName());
TensorType argumentType = importTensorType(value.getTensorShape());
@@ -97,7 +97,7 @@ public class TensorFlowImporter {
}
/** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */
- private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {
+ private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, TensorFlowModel result) {
TypedTensorFunction function = tensorFunctionOf(tfNode, graph, model, result);
// We add all intermediate nodes imported as separate expressions. Only those referenced in a signature output
// will be used
@@ -105,7 +105,7 @@ public class TensorFlowImporter {
return function;
}
- private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {
+ private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, TensorFlowModel result) {
// Import arguments lazily below, as some nodes have arguments unused arguments leading to unsupported ops
// TODO: Implement mapping of more functions from https://www.tensorflow.org/api_docs/python/
switch (tfNode.getOp().toLowerCase()) {
@@ -120,7 +120,7 @@ public class TensorFlowImporter {
}
private List<TypedTensorFunction> importArguments(NodeDef tfNode, GraphDef graph, SavedModelBundle model,
- ImportResult result) {
+ TensorFlowModel result) {
return tfNode.getInputList().stream()
.map(argNode -> importNode(getNode(nameOf(argNode), graph), graph, model, result))
.collect(Collectors.toList());
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java
index 03c0d87fdd0..6740b128d6b 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java
@@ -17,7 +17,7 @@ import java.util.Map;
* @author bratseth
*/
// This object can be built incrementally within this package, but is immutable when observed from outside the package
-public class ImportResult {
+public class TensorFlowModel {
private final Map<String, Signature> signatures = new HashMap<>();
private final Map<String, TensorType> arguments = new HashMap<>();
@@ -69,7 +69,7 @@ public class ImportResult {
void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); }
/** Returns the result this is part of */
- ImportResult owner() { return ImportResult.this; }
+ TensorFlowModel owner() { return TensorFlowModel.this; }
/**
* Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name