summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorJon Bratseth <jonbratseth@yahoo.com>2018-01-11 09:58:49 +0100
committerGitHub <noreply@github.com>2018-01-11 09:58:49 +0100
commit58c0e6c1115950aa479217b9c97c74f0ebd0ec01 (patch)
treec92b33c6423026e2239a8d3d50157f3bbefb3763 /searchlib
parentbc50cf5e58d01cc547926173c480012da2a043fa (diff)
parentd479ea2ed063832297286e60e9ffb2b7f248be59 (diff)
Merge pull request #4603 from vespa-engine/bratseth/propagate-tensorflow-warnings
Bratseth/propagate tensorflow warnings
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java21
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java)22
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java8
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java2
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java15
7 files changed, 35 insertions, 39 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
index bac141644c6..d5958c71454 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
@@ -86,7 +86,7 @@ class OperationMapper {
return new TypedTensorFunction(resultType, function);
}
- TypedTensorFunction placeholder(NodeDef tfNode, ImportResult result) {
+ TypedTensorFunction placeholder(NodeDef tfNode, TensorFlowModel result) {
String name = tfNode.getName();
TensorType type = result.arguments().get(name);
if (type == null)
@@ -96,7 +96,7 @@ class OperationMapper {
return new TypedTensorFunction(type, new VariableTensor(name));
}
- TypedTensorFunction identity(NodeDef tfNode, SavedModelBundle model, ImportResult result) {
+ TypedTensorFunction identity(NodeDef tfNode, SavedModelBundle model, TensorFlowModel result) {
if ( ! tfNode.getName().endsWith("/read"))
throw new IllegalArgumentException("Encountered identity node " + tfNode.getName() + ", but identify " +
"nodes are only supported when reading variables");
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
index 4a6551adca7..e62ff4c54bf 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -34,7 +34,7 @@ public class TensorFlowImporter {
*
* @param modelDir the directory containing the TensorFlow model files to import
*/
- public ImportResult importModel(String modelDir) {
+ public TensorFlowModel importModel(String modelDir) {
try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) {
return importModel(model);
}
@@ -44,7 +44,7 @@ public class TensorFlowImporter {
}
/** Imports a TensorFlow model */
- public ImportResult importModel(SavedModelBundle model) {
+ public TensorFlowModel importModel(SavedModelBundle model) {
try {
return importGraph(MetaGraphDef.parseFrom(model.metaGraphDef()), model);
}
@@ -53,10 +53,10 @@ public class TensorFlowImporter {
}
}
- private ImportResult importGraph(MetaGraphDef graph, SavedModelBundle model) {
- ImportResult result = new ImportResult();
+ private TensorFlowModel importGraph(MetaGraphDef graph, SavedModelBundle model) {
+ TensorFlowModel result = new TensorFlowModel();
for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) {
- ImportResult.Signature signature = result.signature(signatureEntry.getKey()); // Prefer key over "methodName"
+ TensorFlowModel.Signature signature = result.signature(signatureEntry.getKey()); // Prefer key over "methodName"
importInputs(signatureEntry.getValue().getInputsMap(), signature);
for (Map.Entry<String, TensorInfo> output : signatureEntry.getValue().getOutputsMap().entrySet()) {
@@ -67,15 +67,14 @@ public class TensorFlowImporter {
signature.output(outputName, nameOf(output.getValue().getName()));
}
catch (IllegalArgumentException e) {
- result.warn("Skipping output '" + outputName + "' of " + signature +
- ": " + Exceptions.toMessageString(e));
+ signature.skippedOutput(outputName, Exceptions.toMessageString(e));
}
}
}
return result;
}
- private void importInputs(Map<String, TensorInfo> inputInfoMap, ImportResult.Signature signature) {
+ private void importInputs(Map<String, TensorInfo> inputInfoMap, TensorFlowModel.Signature signature) {
inputInfoMap.forEach((key, value) -> {
String argumentName = nameOf(value.getName());
TensorType argumentType = importTensorType(value.getTensorShape());
@@ -98,7 +97,7 @@ public class TensorFlowImporter {
}
/** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */
- private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {
+ private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, TensorFlowModel result) {
TypedTensorFunction function = tensorFunctionOf(tfNode, graph, model, result);
// We add all intermediate nodes imported as separate expressions. Only those referenced in a signature output
// will be used
@@ -106,7 +105,7 @@ public class TensorFlowImporter {
return function;
}
- private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, ImportResult result) {
+ private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, TensorFlowModel result) {
// Import arguments lazily below, as some nodes have arguments unused arguments leading to unsupported ops
// TODO: Implement mapping of more functions from https://www.tensorflow.org/api_docs/python/
switch (tfNode.getOp().toLowerCase()) {
@@ -121,7 +120,7 @@ public class TensorFlowImporter {
}
private List<TypedTensorFunction> importArguments(NodeDef tfNode, GraphDef graph, SavedModelBundle model,
- ImportResult result) {
+ TensorFlowModel result) {
return tfNode.getInputList().stream()
.map(argNode -> importNode(getNode(nameOf(argNode), graph), graph, model, result))
.collect(Collectors.toList());
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java
index 947e6d7a5e1..6740b128d6b 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java
@@ -4,12 +4,9 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
-import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
-import java.util.List;
import java.util.Map;
-import java.util.stream.Collectors;
/**
* The result of importing a TensorFlow model into Vespa.
@@ -20,18 +17,16 @@ import java.util.stream.Collectors;
* @author bratseth
*/
// This object can be built incrementally within this package, but is immutable when observed from outside the package
-public class ImportResult {
+public class TensorFlowModel {
private final Map<String, Signature> signatures = new HashMap<>();
private final Map<String, TensorType> arguments = new HashMap<>();
private final Map<String, Tensor> constants = new HashMap<>();
private final Map<String, RankingExpression> expressions = new HashMap<>();
- private final List<String> warnings = new ArrayList<>();
void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); }
void constant(String name, Tensor constant) { constants.put(name, constant); }
void expression(String name, RankingExpression expression) { expressions.put(name, expression); }
- void warn(String warning) { warnings.add(warning); }
/** Returns the given signature. If it does not already exist it is added to this. */
Signature signature(String name) {
@@ -51,11 +46,6 @@ public class ImportResult {
*/
public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); }
- /** Returns an immutable list, in natural sort order of the warnings generated while importing this */
- public List<String> warnings() {
- return warnings.stream().sorted().collect(Collectors.toList());
- }
-
/** Returns an immutable map of the signatures of this */
public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); }
@@ -68,6 +58,7 @@ public class ImportResult {
private final String name;
private final Map<String, String> inputs = new HashMap<>();
private final Map<String, String> outputs = new HashMap<>();
+ private final Map<String, String> skippedOutputs = new HashMap<>();
Signature(String name) {
this.name = name;
@@ -75,9 +66,10 @@ public class ImportResult {
void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); }
void output(String name, String expressionName) { outputs.put(name, expressionName); }
+ void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); }
/** Returns the result this is part of */
- ImportResult owner() { return ImportResult.this; }
+ TensorFlowModel owner() { return TensorFlowModel.this; }
/**
* Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name
@@ -91,6 +83,12 @@ public class ImportResult {
/** Returns an immutable list of the expression names of this */
public Map<String, String> outputs() { return Collections.unmodifiableMap(outputs); }
+ /**
+ * Returns an immutable list of the outputs of this which could not be imported,
+ * with a string detailing the reason for each
+ */
+ public Map<String, String> skippedOutputs() { return Collections.unmodifiableMap(skippedOutputs); }
+
/** Returns owner().expressions().get(outputs.get(outputName)), e.g the expression this output references */
public RankingExpression outputExpression(String outputName) { return owner().expressions().get(outputs.get(outputName)); }
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java
index 1b8239ba643..216b677f6ff 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java
@@ -16,7 +16,7 @@ import java.util.List;
*
* @author bratseth
*/
-public class ConstantDereferencer extends ExpressionTransformer {
+public class ConstantDereferencer extends ExpressionTransformer<TransformContext> {
@Override
public ExpressionNode transform(ExpressionNode node, TransformContext context) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java
index c585c0dea1f..e5d0b4671c0 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java
@@ -13,20 +13,20 @@ import java.util.List;
*
* @author bratseth
*/
-public abstract class ExpressionTransformer {
+public abstract class ExpressionTransformer<CONTEXT extends TransformContext> {
- public RankingExpression transform(RankingExpression expression, TransformContext context) {
+ public RankingExpression transform(RankingExpression expression, CONTEXT context) {
return new RankingExpression(expression.getName(), transform(expression.getRoot(), context));
}
/** Transforms an expression node and returns the transformed node */
- public abstract ExpressionNode transform(ExpressionNode node, TransformContext context);
+ public abstract ExpressionNode transform(ExpressionNode node, CONTEXT context);
/**
* Utility method which calls transform on each child of the given node and return the resulting transformed
* composite
*/
- protected CompositeNode transformChildren(CompositeNode node, TransformContext context) {
+ protected CompositeNode transformChildren(CompositeNode node, CONTEXT context) {
List<ExpressionNode> children = node.children();
List<ExpressionNode> transformedChildren = new ArrayList<>(children.size());
for (ExpressionNode child : children)
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
index 9e8491340b0..e8e2fdf2454 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
@@ -20,7 +20,7 @@ import java.util.List;
*
* @author bratseth
*/
-public class Simplifier extends ExpressionTransformer {
+public class Simplifier extends ExpressionTransformer<TransformContext> {
@Override
public ExpressionNode transform(ExpressionNode node, TransformContext context) {
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java
index 0370fc7fc94..6536a2f2a06 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java
@@ -28,11 +28,7 @@ public class Mnist_SoftmaxTestCase {
public void testImporting() {
String modelDir = "src/test/files/integration/tensorflow/mnist_softmax/saved";
SavedModelBundle model = SavedModelBundle.load(modelDir, "serve");
- ImportResult result = new TensorFlowImporter().importModel(model);
-
- // Check logged messages
- result.warnings().forEach(System.err::println);
- assertEquals(0, result.warnings().size());
+ TensorFlowModel result = new TensorFlowImporter().importModel(model);
// Check constants
assertEquals(2, result.constants().size());
@@ -51,7 +47,7 @@ public class Mnist_SoftmaxTestCase {
// Check signatures
assertEquals(1, result.signatures().size());
- ImportResult.Signature signature = result.signatures().get("serving_default");
+ TensorFlowModel.Signature signature = result.signatures().get("serving_default");
assertNotNull(signature);
// ... signature inputs
@@ -71,6 +67,9 @@ public class Mnist_SoftmaxTestCase {
"f(a,b)(a + b))",
toNonPrimitiveString(output));
+ // ... skipped outputs
+ assertEquals(0, signature.skippedOutputs().size());
+
// Test execution
assertEqualResult(model, result, "Variable/read");
assertEqualResult(model, result, "Variable_1/read");
@@ -78,7 +77,7 @@ public class Mnist_SoftmaxTestCase {
assertEqualResult(model, result, "add");
}
- private void assertEqualResult(SavedModelBundle model, ImportResult result, String operationName) {
+ private void assertEqualResult(SavedModelBundle model, TensorFlowModel result, String operationName) {
Tensor tfResult = tensorFlowExecute(model, operationName);
Context context = contextFrom(result);
Tensor placeholder = placeholderArgument();
@@ -96,7 +95,7 @@ public class Mnist_SoftmaxTestCase {
return new TensorConverter().toVespaTensor(results.get(0));
}
- private Context contextFrom(ImportResult result) {
+ private Context contextFrom(TensorFlowModel result) {
MapContext context = new MapContext();
result.constants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
return context;