diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2018-01-10 12:50:08 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2018-01-10 12:50:08 +0100 |
commit | d479ea2ed063832297286e60e9ffb2b7f248be59 (patch) | |
tree | e1d27ccc652d2be336159fafb37ea94d14cb1d2e /searchlib/src/main | |
parent | 9ef5fd6f9edf47d48c34cd6a8623ac38daa933f5 (diff) |
ImportResult -> TensorFlowModel
Diffstat (limited to 'searchlib/src/main')
-rw-r--r-- | searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java | 4 | ||||
-rw-r--r-- | searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java | 18 | ||||
-rw-r--r-- | searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java) | 4 |
3 files changed, 13 insertions, 13 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java index bac141644c6..d5958c71454 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java @@ -86,7 +86,7 @@ class OperationMapper { return new TypedTensorFunction(resultType, function); } - TypedTensorFunction placeholder(NodeDef tfNode, ImportResult result) { + TypedTensorFunction placeholder(NodeDef tfNode, TensorFlowModel result) { String name = tfNode.getName(); TensorType type = result.arguments().get(name); if (type == null) @@ -96,7 +96,7 @@ class OperationMapper { return new TypedTensorFunction(type, new VariableTensor(name)); } - TypedTensorFunction identity(NodeDef tfNode, SavedModelBundle model, ImportResult result) { + TypedTensorFunction identity(NodeDef tfNode, SavedModelBundle model, TensorFlowModel result) { if ( ! tfNode.getName().endsWith("/read")) throw new IllegalArgumentException("Encountered identity node " + tfNode.getName() + ", but identify " + "nodes are only supported when reading variables"); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index 69781fa915c..e62ff4c54bf 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -34,7 +34,7 @@ public class TensorFlowImporter { * * @param modelDir the directory containing the TensorFlow model files to import */ - public ImportResult importModel(String modelDir) { + public TensorFlowModel importModel(String modelDir) { try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) { return importModel(model); } @@ -44,7 +44,7 @@ public class TensorFlowImporter { } /** Imports a TensorFlow model */ - public ImportResult importModel(SavedModelBundle model) { + public TensorFlowModel importModel(SavedModelBundle model) { try { return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), model); } @@ -53,10 +53,10 @@ public class TensorFlowImporter { } } - private ImportResult importGraph(MetaGraphDef graph, SavedModelBundle model) { - ImportResult result = new ImportResult(); + private TensorFlowModel importGraph(MetaGraphDef graph, SavedModelBundle model) { + TensorFlowModel result = new TensorFlowModel(); for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) { - ImportResult.Signature signature = result.signature(signatureEntry.getKey()); // Prefer key over "methodName" + TensorFlowModel.Signature signature = result.signature(signatureEntry.getKey()); // Prefer key over "methodName" importInputs(signatureEntry.getValue().getInputsMap(), signature); for (Map.Entry<String, TensorInfo> output : signatureEntry.getValue().getOutputsMap().entrySet()) { @@ -74,7 +74,7 @@ public class TensorFlowImporter { return result; } - private void importInputs(Map<String, TensorInfo> inputInfoMap, ImportResult.Signature signature) { + private void importInputs(Map<String, TensorInfo> inputInfoMap, TensorFlowModel.Signature signature) { inputInfoMap.forEach((key, value) -> { String argumentName = nameOf(value.getName()); TensorType argumentType = importTensorType(value.getTensorShape()); @@ -97,7 +97,7 @@ public class TensorFlowImporter { } /** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */ - private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) { + private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, TensorFlowModel result) { TypedTensorFunction function = tensorFunctionOf(tfNode, graph, model, result); // We add all intermediate nodes imported as separate expressions. Only those referenced in a signature output // will be used @@ -105,7 +105,7 @@ public class TensorFlowImporter { return function; } - private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) { + private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, TensorFlowModel result) { // Import arguments lazily below, as some nodes have arguments unused arguments leading to unsupported ops // TODO: Implement mapping of more functions from https://www.tensorflow.org/api_docs/python/ switch (tfNode.getOp().toLowerCase()) { @@ -120,7 +120,7 @@ public class TensorFlowImporter { } private List<TypedTensorFunction> importArguments(NodeDef tfNode, GraphDef graph, SavedModelBundle model, - ImportResult result) { + TensorFlowModel result) { return tfNode.getInputList().stream() .map(argNode -> importNode(getNode(nameOf(argNode), graph), graph, model, result)) .collect(Collectors.toList()); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java index 03c0d87fdd0..6740b128d6b 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java @@ -17,7 +17,7 @@ import java.util.Map; * @author bratseth */ // This object can be built incrementally within this package, but is immutable when observed from outside the package -public class ImportResult { +public class TensorFlowModel { private final Map<String, Signature> signatures = new HashMap<>(); private final Map<String, TensorType> arguments = new HashMap<>(); @@ -69,7 +69,7 @@ public class ImportResult { void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); } /** Returns the result this is part of */ - ImportResult owner() { return ImportResult.this; } + TensorFlowModel owner() { return TensorFlowModel.this; } /** * Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name |