summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2018-01-10 12:34:10 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2018-01-10 12:34:10 +0100
commitfc3e000141fa6bca30231bc49ce472b01759304d (patch)
treef966bcad948a3ece4fe13fc2d074f3674203c29e /searchlib
parent217674665ab15f15bbda2a2d2b49a3858ca1b319 (diff)
Store warnings under the right output
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java18
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java3
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java7
3 files changed, 12 insertions, 16 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java
index 947e6d7a5e1..03c0d87fdd0 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java
@@ -4,12 +4,9 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
-import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
-import java.util.List;
import java.util.Map;
-import java.util.stream.Collectors;
/**
* The result of importing a TensorFlow model into Vespa.
@@ -26,12 +23,10 @@ public class ImportResult {
private final Map<String, TensorType> arguments = new HashMap<>();
private final Map<String, Tensor> constants = new HashMap<>();
private final Map<String, RankingExpression> expressions = new HashMap<>();
- private final List<String> warnings = new ArrayList<>();
void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); }
void constant(String name, Tensor constant) { constants.put(name, constant); }
void expression(String name, RankingExpression expression) { expressions.put(name, expression); }
- void warn(String warning) { warnings.add(warning); }
/** Returns the given signature. If it does not already exist it is added to this. */
Signature signature(String name) {
@@ -51,11 +46,6 @@ public class ImportResult {
*/
public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); }
- /** Returns an immutable list, in natural sort order of the warnings generated while importing this */
- public List<String> warnings() {
- return warnings.stream().sorted().collect(Collectors.toList());
- }
-
/** Returns an immutable map of the signatures of this */
public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); }
@@ -68,6 +58,7 @@ public class ImportResult {
private final String name;
private final Map<String, String> inputs = new HashMap<>();
private final Map<String, String> outputs = new HashMap<>();
+ private final Map<String, String> skippedOutputs = new HashMap<>();
Signature(String name) {
this.name = name;
@@ -75,6 +66,7 @@ public class ImportResult {
void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); }
void output(String name, String expressionName) { outputs.put(name, expressionName); }
+ void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); }
/** Returns the result this is part of */
ImportResult owner() { return ImportResult.this; }
@@ -91,6 +83,12 @@ public class ImportResult {
/** Returns an immutable list of the expression names of this */
public Map<String, String> outputs() { return Collections.unmodifiableMap(outputs); }
+ /**
+ * Returns an immutable list of the outputs of this which could not be imported,
+ * with a string detailing the reason for each
+ */
+ public Map<String, String> skippedOutputs() { return Collections.unmodifiableMap(skippedOutputs); }
+
/** Returns owner().expressions().get(outputs.get(outputName)), e.g the expression this output references */
public RankingExpression outputExpression(String outputName) { return owner().expressions().get(outputs.get(outputName)); }
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
index 4a6551adca7..69781fa915c 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -67,8 +67,7 @@ public class TensorFlowImporter {
signature.output(outputName, nameOf(output.getValue().getName()));
}
catch (IllegalArgumentException e) {
- result.warn("Skipping output '" + outputName + "' of " + signature +
- ": " + Exceptions.toMessageString(e));
+ signature.skippedOutput(outputName, Exceptions.toMessageString(e));
}
}
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java
index 0370fc7fc94..8efeceaa034 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java
@@ -30,10 +30,6 @@ public class Mnist_SoftmaxTestCase {
SavedModelBundle model = SavedModelBundle.load(modelDir, "serve");
ImportResult result = new TensorFlowImporter().importModel(model);
- // Check logged messages
- result.warnings().forEach(System.err::println);
- assertEquals(0, result.warnings().size());
-
// Check constants
assertEquals(2, result.constants().size());
@@ -71,6 +67,9 @@ public class Mnist_SoftmaxTestCase {
"f(a,b)(a + b))",
toNonPrimitiveString(output));
+ // ... skipped outputs
+ assertEquals(0, signature.skippedOutputs().size());
+
// Test execution
assertEqualResult(model, result, "Variable/read");
assertEqualResult(model, result, "Variable_1/read");