diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2018-01-10 12:34:10 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2018-01-10 12:34:10 +0100 |
commit | fc3e000141fa6bca30231bc49ce472b01759304d (patch) | |
tree | f966bcad948a3ece4fe13fc2d074f3674203c29e /searchlib | |
parent | 217674665ab15f15bbda2a2d2b49a3858ca1b319 (diff) |
Store warnings under the right output
Diffstat (limited to 'searchlib')
3 files changed, 12 insertions, 16 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java index 947e6d7a5e1..03c0d87fdd0 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java @@ -4,12 +4,9 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; -import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; -import java.util.List; import java.util.Map; -import java.util.stream.Collectors; /** * The result of importing a TensorFlow model into Vespa. @@ -26,12 +23,10 @@ public class ImportResult { private final Map<String, TensorType> arguments = new HashMap<>(); private final Map<String, Tensor> constants = new HashMap<>(); private final Map<String, RankingExpression> expressions = new HashMap<>(); - private final List<String> warnings = new ArrayList<>(); void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); } void constant(String name, Tensor constant) { constants.put(name, constant); } void expression(String name, RankingExpression expression) { expressions.put(name, expression); } - void warn(String warning) { warnings.add(warning); } /** Returns the given signature. If it does not already exist it is added to this. */ Signature signature(String name) { @@ -51,11 +46,6 @@ public class ImportResult { */ public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); } - /** Returns an immutable list, in natural sort order of the warnings generated while importing this */ - public List<String> warnings() { - return warnings.stream().sorted().collect(Collectors.toList()); - } - /** Returns an immutable map of the signatures of this */ public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); } @@ -68,6 +58,7 @@ public class ImportResult { private final String name; private final Map<String, String> inputs = new HashMap<>(); private final Map<String, String> outputs = new HashMap<>(); + private final Map<String, String> skippedOutputs = new HashMap<>(); Signature(String name) { this.name = name; @@ -75,6 +66,7 @@ public class ImportResult { void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); } void output(String name, String expressionName) { outputs.put(name, expressionName); } + void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); } /** Returns the result this is part of */ ImportResult owner() { return ImportResult.this; } @@ -91,6 +83,12 @@ public class ImportResult { /** Returns an immutable list of the expression names of this */ public Map<String, String> outputs() { return Collections.unmodifiableMap(outputs); } + /** + * Returns an immutable list of the outputs of this which could not be imported, + * with a string detailing the reason for each + */ + public Map<String, String> skippedOutputs() { return Collections.unmodifiableMap(skippedOutputs); } + /** Returns owner().expressions().get(outputs.get(outputName)), e.g the expression this output references */ public RankingExpression outputExpression(String outputName) { return owner().expressions().get(outputs.get(outputName)); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index 4a6551adca7..69781fa915c 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -67,8 +67,7 @@ public class TensorFlowImporter { signature.output(outputName, nameOf(output.getValue().getName())); } catch (IllegalArgumentException e) { - result.warn("Skipping output '" + outputName + "' of " + signature + - ": " + Exceptions.toMessageString(e)); + signature.skippedOutput(outputName, Exceptions.toMessageString(e)); } } } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java index 0370fc7fc94..8efeceaa034 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java @@ -30,10 +30,6 @@ public class Mnist_SoftmaxTestCase { SavedModelBundle model = SavedModelBundle.load(modelDir, "serve"); ImportResult result = new TensorFlowImporter().importModel(model); - // Check logged messages - result.warnings().forEach(System.err::println); - assertEquals(0, result.warnings().size()); - // Check constants assertEquals(2, result.constants().size()); @@ -71,6 +67,9 @@ public class Mnist_SoftmaxTestCase { "f(a,b)(a + b))", toNonPrimitiveString(output)); + // ... skipped outputs + assertEquals(0, signature.skippedOutputs().size()); + // Test execution assertEqualResult(model, result, "Variable/read"); assertEqualResult(model, result, "Variable_1/read"); |