aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2018-01-10 12:34:10 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2018-01-10 12:34:10 +0100
commitfc3e000141fa6bca30231bc49ce472b01759304d (patch)
treef966bcad948a3ece4fe13fc2d074f3674203c29e
parent217674665ab15f15bbda2a2d2b49a3858ca1b319 (diff)
Store warnings under the right output
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java68
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java8
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java18
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java3
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java7
5 files changed, 64 insertions, 40 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index fdc90d2334e..fb4f5b0a5a9 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -1,5 +1,7 @@
package com.yahoo.searchdefinition.expressiontransforms;
+import com.google.common.base.Joiner;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.ImportResult;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter;
@@ -50,40 +52,66 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
ImportResult result = importedModels.computeIfAbsent(modelPath, k -> tensorFlowImporter.importModel(modelPath));
// Find the specified expression
- ImportResult.Signature signature = chooseOrDefault("signatures", result.signatures(),
+ ImportResult.Signature signature = chooseSignature(result,
optionalArgument(1, feature.getArguments()));
- String output = chooseOrDefault("outputs", signature.outputs(),
- optionalArgument(2, feature.getArguments()));
+ RankingExpression expression = chooseOutput(signature,
+ optionalArgument(2, feature.getArguments()));
- // Add all constants
+ // Add all constants (after finding outputs to fail faster when the output is not found)
result.constants().forEach((k, v) -> context.rankProfile().addConstantTensor(k, new TensorValue(v)));
- return result.expressions().get(output).getRoot();
+ return expression.getRoot();
}
catch (IllegalArgumentException e) {
- throw new IllegalArgumentException("Could not import tensorflow model from " + feature, e);
+ throw new IllegalArgumentException("Could not use tensorflow model from " + feature, e);
}
}
/**
- * Returns the specified, existing map value, or the only map value if no key is specified.
+ * Returns the specified, existing signature, or the only signature if none is specified.
* Throws IllegalArgumentException in all other cases.
*/
- private <T> T chooseOrDefault(String valueDescription, Map<String, T> map, Optional<String> key) {
- if ( ! key.isPresent()) {
- if (map.size() == 0)
- throw new IllegalArgumentException("No " + valueDescription + " are present");
- if (map.size() > 1)
- throw new IllegalArgumentException("Model has multiple " + valueDescription + ", but no " +
- valueDescription + " argument is specified");
- return map.values().stream().findFirst().get();
+ private ImportResult.Signature chooseSignature(ImportResult importResult, Optional<String> signatureName) {
+ if ( ! signatureName.isPresent()) {
+ if (importResult.signatures().size() == 0)
+ throw new IllegalArgumentException("No signatures are present");
+ if (importResult.signatures().size() > 1)
+ throw new IllegalArgumentException("Model has multiple signatures (" +
+ Joiner.on(", ").join(importResult.signatures().keySet()) +
+ "), one must be specified " +
+ "as a second argument to tensorflow()");
+ return importResult.signatures().values().stream().findFirst().get();
}
else {
- T value = map.get(key.get());
- if (value == null)
- throw new IllegalArgumentException("Model does not have the specified " +
- valueDescription + " '" + key.get() + "'");
- return value;
+ ImportResult.Signature signature = importResult.signatures().get(signatureName.get());
+ if (signature == null)
+ throw new IllegalArgumentException("Model does not have the specified signature '" +
+ signatureName.get() + "'");
+ return signature;
+ }
+ }
+
+ /**
+ * Returns the specified, existing output expression, or the only output expression if no output name is specified.
+ * Throws IllegalArgumentException in all other cases.
+ */
+ private RankingExpression chooseOutput(ImportResult.Signature signature, Optional<String> outputName) {
+ if ( ! outputName.isPresent()) {
+ if (signature.outputs().size() == 0)
+ throw new IllegalArgumentException("No signatures are present");
+ if (signature.outputs().size() > 1)
+ throw new IllegalArgumentException(signature + " has multiple outputs (" +
+ Joiner.on(", ").join(signature.outputs().keySet()) +
+ "), one must be specified " +
+ "as a third argument to tensorflow()");
+ return signature.outputExpression(signature.outputs().keySet().stream().findFirst().get());
+ }
+ else {
+ RankingExpression expression = signature.outputExpression(outputName.get());
+ if (expression == null)
+ throw new IllegalArgumentException("Model does not have the specified output '" +
+ outputName.get() + "'");
+ return expression;
}
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index 5ad85ac872c..8fcd821adfd 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -91,8 +91,8 @@ public class RankingExpressionWithTensorFlowTestCase {
fail("Expecting exception");
}
catch (IllegalArgumentException expected) {
- assertEquals("Rank profile 'my_profile' is invalid: Could not import tensorflow model from tensorflow('" +
- modelDirectory + "','serving_defaultz'): Model does not have the specified signatures 'serving_defaultz'",
+ assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from tensorflow('" +
+ modelDirectory + "','serving_defaultz'): Model does not have the specified signature 'serving_defaultz'",
Exceptions.toMessageString(expected));
}
}
@@ -110,8 +110,8 @@ public class RankingExpressionWithTensorFlowTestCase {
fail("Expecting exception");
}
catch (IllegalArgumentException expected) {
- assertEquals("Rank profile 'my_profile' is invalid: Could not import tensorflow model from tensorflow('" +
- modelDirectory + "','serving_default','x'): Model does not have the specified outputs 'x'",
+ assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from tensorflow('" +
+ modelDirectory + "','serving_default','x'): Model does not have the specified output 'x'",
Exceptions.toMessageString(expected));
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java
index 947e6d7a5e1..03c0d87fdd0 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java
@@ -4,12 +4,9 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
-import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
-import java.util.List;
import java.util.Map;
-import java.util.stream.Collectors;
/**
* The result of importing a TensorFlow model into Vespa.
@@ -26,12 +23,10 @@ public class ImportResult {
private final Map<String, TensorType> arguments = new HashMap<>();
private final Map<String, Tensor> constants = new HashMap<>();
private final Map<String, RankingExpression> expressions = new HashMap<>();
- private final List<String> warnings = new ArrayList<>();
void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); }
void constant(String name, Tensor constant) { constants.put(name, constant); }
void expression(String name, RankingExpression expression) { expressions.put(name, expression); }
- void warn(String warning) { warnings.add(warning); }
/** Returns the given signature. If it does not already exist it is added to this. */
Signature signature(String name) {
@@ -51,11 +46,6 @@ public class ImportResult {
*/
public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); }
- /** Returns an immutable list, in natural sort order of the warnings generated while importing this */
- public List<String> warnings() {
- return warnings.stream().sorted().collect(Collectors.toList());
- }
-
/** Returns an immutable map of the signatures of this */
public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); }
@@ -68,6 +58,7 @@ public class ImportResult {
private final String name;
private final Map<String, String> inputs = new HashMap<>();
private final Map<String, String> outputs = new HashMap<>();
+ private final Map<String, String> skippedOutputs = new HashMap<>();
Signature(String name) {
this.name = name;
@@ -75,6 +66,7 @@ public class ImportResult {
void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); }
void output(String name, String expressionName) { outputs.put(name, expressionName); }
+ void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); }
/** Returns the result this is part of */
ImportResult owner() { return ImportResult.this; }
@@ -91,6 +83,12 @@ public class ImportResult {
/** Returns an immutable list of the expression names of this */
public Map<String, String> outputs() { return Collections.unmodifiableMap(outputs); }
+ /**
+ * Returns an immutable list of the outputs of this which could not be imported,
+ * with a string detailing the reason for each
+ */
+ public Map<String, String> skippedOutputs() { return Collections.unmodifiableMap(skippedOutputs); }
+
/** Returns owner().expressions().get(outputs.get(outputName)), e.g the expression this output references */
public RankingExpression outputExpression(String outputName) { return owner().expressions().get(outputs.get(outputName)); }
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
index 4a6551adca7..69781fa915c 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -67,8 +67,7 @@ public class TensorFlowImporter {
signature.output(outputName, nameOf(output.getValue().getName()));
}
catch (IllegalArgumentException e) {
- result.warn("Skipping output '" + outputName + "' of " + signature +
- ": " + Exceptions.toMessageString(e));
+ signature.skippedOutput(outputName, Exceptions.toMessageString(e));
}
}
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java
index 0370fc7fc94..8efeceaa034 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java
@@ -30,10 +30,6 @@ public class Mnist_SoftmaxTestCase {
SavedModelBundle model = SavedModelBundle.load(modelDir, "serve");
ImportResult result = new TensorFlowImporter().importModel(model);
- // Check logged messages
- result.warnings().forEach(System.err::println);
- assertEquals(0, result.warnings().size());
-
// Check constants
assertEquals(2, result.constants().size());
@@ -71,6 +67,9 @@ public class Mnist_SoftmaxTestCase {
"f(a,b)(a + b))",
toNonPrimitiveString(output));
+ // ... skipped outputs
+ assertEquals(0, signature.skippedOutputs().size());
+
// Test execution
assertEqualResult(model, result, "Variable/read");
assertEqualResult(model, result, "Variable_1/read");