diff options
author | Jon Bratseth <jonbratseth@yahoo.com> | 2018-01-11 09:58:49 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-01-11 09:58:49 +0100 |
commit | 58c0e6c1115950aa479217b9c97c74f0ebd0ec01 (patch) | |
tree | c92b33c6423026e2239a8d3d50157f3bbefb3763 /searchlib | |
parent | bc50cf5e58d01cc547926173c480012da2a043fa (diff) | |
parent | d479ea2ed063832297286e60e9ffb2b7f248be59 (diff) |
Merge pull request #4603 from vespa-engine/bratseth/propagate-tensorflow-warnings
Bratseth/propagate tensorflow warnings
Diffstat (limited to 'searchlib')
-rw-r--r-- | searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java | 4 | ||||
-rw-r--r-- | searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java | 21 | ||||
-rw-r--r-- | searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java) | 22 | ||||
-rw-r--r-- | searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java | 2 | ||||
-rw-r--r-- | searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java | 8 | ||||
-rw-r--r-- | searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java | 2 | ||||
-rw-r--r-- | searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java | 15 |
7 files changed, 35 insertions, 39 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java index bac141644c6..d5958c71454 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java @@ -86,7 +86,7 @@ class OperationMapper { return new TypedTensorFunction(resultType, function); } - TypedTensorFunction placeholder(NodeDef tfNode, ImportResult result) { + TypedTensorFunction placeholder(NodeDef tfNode, TensorFlowModel result) { String name = tfNode.getName(); TensorType type = result.arguments().get(name); if (type == null) @@ -96,7 +96,7 @@ class OperationMapper { return new TypedTensorFunction(type, new VariableTensor(name)); } - TypedTensorFunction identity(NodeDef tfNode, SavedModelBundle model, ImportResult result) { + TypedTensorFunction identity(NodeDef tfNode, SavedModelBundle model, TensorFlowModel result) { if ( ! tfNode.getName().endsWith("/read")) throw new IllegalArgumentException("Encountered identity node " + tfNode.getName() + ", but identify " + "nodes are only supported when reading variables"); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index 4a6551adca7..e62ff4c54bf 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -34,7 +34,7 @@ public class TensorFlowImporter { * * @param modelDir the directory containing the TensorFlow model files to import */ - public ImportResult importModel(String modelDir) { + public TensorFlowModel importModel(String modelDir) { try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) { return importModel(model); } @@ -44,7 +44,7 @@ public class TensorFlowImporter { } /** Imports a TensorFlow model */ - public ImportResult importModel(SavedModelBundle model) { + public TensorFlowModel importModel(SavedModelBundle model) { try { return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), model); } @@ -53,10 +53,10 @@ public class TensorFlowImporter { } } - private ImportResult importGraph(MetaGraphDef graph, SavedModelBundle model) { - ImportResult result = new ImportResult(); + private TensorFlowModel importGraph(MetaGraphDef graph, SavedModelBundle model) { + TensorFlowModel result = new TensorFlowModel(); for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) { - ImportResult.Signature signature = result.signature(signatureEntry.getKey()); // Prefer key over "methodName" + TensorFlowModel.Signature signature = result.signature(signatureEntry.getKey()); // Prefer key over "methodName" importInputs(signatureEntry.getValue().getInputsMap(), signature); for (Map.Entry<String, TensorInfo> output : signatureEntry.getValue().getOutputsMap().entrySet()) { @@ -67,15 +67,14 @@ public class TensorFlowImporter { signature.output(outputName, nameOf(output.getValue().getName())); } catch (IllegalArgumentException e) { - result.warn("Skipping output '" + outputName + "' of " + signature + - ": " + Exceptions.toMessageString(e)); + signature.skippedOutput(outputName, Exceptions.toMessageString(e)); } } } return result; } - private void importInputs(Map<String, TensorInfo> inputInfoMap, ImportResult.Signature signature) { + private void importInputs(Map<String, TensorInfo> inputInfoMap, TensorFlowModel.Signature signature) { inputInfoMap.forEach((key, value) -> { String argumentName = nameOf(value.getName()); TensorType argumentType = importTensorType(value.getTensorShape()); @@ -98,7 +97,7 @@ public class TensorFlowImporter { } /** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */ - private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) { + private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, TensorFlowModel result) { TypedTensorFunction function = tensorFunctionOf(tfNode, graph, model, result); // We add all intermediate nodes imported as separate expressions. Only those referenced in a signature output // will be used @@ -106,7 +105,7 @@ public class TensorFlowImporter { return function; } - private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) { + private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, TensorFlowModel result) { // Import arguments lazily below, as some nodes have arguments unused arguments leading to unsupported ops // TODO: Implement mapping of more functions from https://www.tensorflow.org/api_docs/python/ switch (tfNode.getOp().toLowerCase()) { @@ -121,7 +120,7 @@ public class TensorFlowImporter { } private List<TypedTensorFunction> importArguments(NodeDef tfNode, GraphDef graph, SavedModelBundle model, - ImportResult result) { + TensorFlowModel result) { return tfNode.getInputList().stream() .map(argNode -> importNode(getNode(nameOf(argNode), graph), graph, model, result)) .collect(Collectors.toList()); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java index 947e6d7a5e1..6740b128d6b 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java @@ -4,12 +4,9 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; -import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; -import java.util.List; import java.util.Map; -import java.util.stream.Collectors; /** * The result of importing a TensorFlow model into Vespa. @@ -20,18 +17,16 @@ import java.util.stream.Collectors; * @author bratseth */ // This object can be built incrementally within this package, but is immutable when observed from outside the package -public class ImportResult { +public class TensorFlowModel { private final Map<String, Signature> signatures = new HashMap<>(); private final Map<String, TensorType> arguments = new HashMap<>(); private final Map<String, Tensor> constants = new HashMap<>(); private final Map<String, RankingExpression> expressions = new HashMap<>(); - private final List<String> warnings = new ArrayList<>(); void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); } void constant(String name, Tensor constant) { constants.put(name, constant); } void expression(String name, RankingExpression expression) { expressions.put(name, expression); } - void warn(String warning) { warnings.add(warning); } /** Returns the given signature. If it does not already exist it is added to this. */ Signature signature(String name) { @@ -51,11 +46,6 @@ public class ImportResult { */ public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); } - /** Returns an immutable list, in natural sort order of the warnings generated while importing this */ - public List<String> warnings() { - return warnings.stream().sorted().collect(Collectors.toList()); - } - /** Returns an immutable map of the signatures of this */ public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); } @@ -68,6 +58,7 @@ public class ImportResult { private final String name; private final Map<String, String> inputs = new HashMap<>(); private final Map<String, String> outputs = new HashMap<>(); + private final Map<String, String> skippedOutputs = new HashMap<>(); Signature(String name) { this.name = name; @@ -75,9 +66,10 @@ public class ImportResult { void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); } void output(String name, String expressionName) { outputs.put(name, expressionName); } + void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); } /** Returns the result this is part of */ - ImportResult owner() { return ImportResult.this; } + TensorFlowModel owner() { return TensorFlowModel.this; } /** * Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name @@ -91,6 +83,12 @@ public class ImportResult { /** Returns an immutable list of the expression names of this */ public Map<String, String> outputs() { return Collections.unmodifiableMap(outputs); } + /** + * Returns an immutable list of the outputs of this which could not be imported, + * with a string detailing the reason for each + */ + public Map<String, String> skippedOutputs() { return Collections.unmodifiableMap(skippedOutputs); } + /** Returns owner().expressions().get(outputs.get(outputName)), e.g the expression this output references */ public RankingExpression outputExpression(String outputName) { return owner().expressions().get(outputs.get(outputName)); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java index 1b8239ba643..216b677f6ff 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java @@ -16,7 +16,7 @@ import java.util.List; * * @author bratseth */ -public class ConstantDereferencer extends ExpressionTransformer { +public class ConstantDereferencer extends ExpressionTransformer<TransformContext> { @Override public ExpressionNode transform(ExpressionNode node, TransformContext context) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java index c585c0dea1f..e5d0b4671c0 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java @@ -13,20 +13,20 @@ import java.util.List; * * @author bratseth */ -public abstract class ExpressionTransformer { +public abstract class ExpressionTransformer<CONTEXT extends TransformContext> { - public RankingExpression transform(RankingExpression expression, TransformContext context) { + public RankingExpression transform(RankingExpression expression, CONTEXT context) { return new RankingExpression(expression.getName(), transform(expression.getRoot(), context)); } /** Transforms an expression node and returns the transformed node */ - public abstract ExpressionNode transform(ExpressionNode node, TransformContext context); + public abstract ExpressionNode transform(ExpressionNode node, CONTEXT context); /** * Utility method which calls transform on each child of the given node and return the resulting transformed * composite */ - protected CompositeNode transformChildren(CompositeNode node, TransformContext context) { + protected CompositeNode transformChildren(CompositeNode node, CONTEXT context) { List<ExpressionNode> children = node.children(); List<ExpressionNode> transformedChildren = new ArrayList<>(children.size()); for (ExpressionNode child : children) diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java index 9e8491340b0..e8e2fdf2454 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java @@ -20,7 +20,7 @@ import java.util.List; * * @author bratseth */ -public class Simplifier extends ExpressionTransformer { +public class Simplifier extends ExpressionTransformer<TransformContext> { @Override public ExpressionNode transform(ExpressionNode node, TransformContext context) { diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java index 0370fc7fc94..6536a2f2a06 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java @@ -28,11 +28,7 @@ public class Mnist_SoftmaxTestCase { public void testImporting() { String modelDir = "src/test/files/integration/tensorflow/mnist_softmax/saved"; SavedModelBundle model = SavedModelBundle.load(modelDir, "serve"); - ImportResult result = new TensorFlowImporter().importModel(model); - - // Check logged messages - result.warnings().forEach(System.err::println); - assertEquals(0, result.warnings().size()); + TensorFlowModel result = new TensorFlowImporter().importModel(model); // Check constants assertEquals(2, result.constants().size()); @@ -51,7 +47,7 @@ public class Mnist_SoftmaxTestCase { // Check signatures assertEquals(1, result.signatures().size()); - ImportResult.Signature signature = result.signatures().get("serving_default"); + TensorFlowModel.Signature signature = result.signatures().get("serving_default"); assertNotNull(signature); // ... signature inputs @@ -71,6 +67,9 @@ public class Mnist_SoftmaxTestCase { "f(a,b)(a + b))", toNonPrimitiveString(output)); + // ... skipped outputs + assertEquals(0, signature.skippedOutputs().size()); + // Test execution assertEqualResult(model, result, "Variable/read"); assertEqualResult(model, result, "Variable_1/read"); @@ -78,7 +77,7 @@ public class Mnist_SoftmaxTestCase { assertEqualResult(model, result, "add"); } - private void assertEqualResult(SavedModelBundle model, ImportResult result, String operationName) { + private void assertEqualResult(SavedModelBundle model, TensorFlowModel result, String operationName) { Tensor tfResult = tensorFlowExecute(model, operationName); Context context = contextFrom(result); Tensor placeholder = placeholderArgument(); @@ -96,7 +95,7 @@ public class Mnist_SoftmaxTestCase { return new TensorConverter().toVespaTensor(results.get(0)); } - private Context contextFrom(ImportResult result) { + private Context contextFrom(TensorFlowModel result) { MapContext context = new MapContext(); result.constants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); return context; |