diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2018-01-10 12:34:10 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2018-01-10 12:34:10 +0100 |
commit | fc3e000141fa6bca30231bc49ce472b01759304d (patch) | |
tree | f966bcad948a3ece4fe13fc2d074f3674203c29e | |
parent | 217674665ab15f15bbda2a2d2b49a3858ca1b319 (diff) |
Store warnings under the right output
5 files changed, 64 insertions, 40 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index fdc90d2334e..fb4f5b0a5a9 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -1,5 +1,7 @@ package com.yahoo.searchdefinition.expressiontransforms; +import com.google.common.base.Joiner; +import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.integration.tensorflow.ImportResult; import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter; @@ -50,40 +52,66 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil ImportResult result = importedModels.computeIfAbsent(modelPath, k -> tensorFlowImporter.importModel(modelPath)); // Find the specified expression - ImportResult.Signature signature = chooseOrDefault("signatures", result.signatures(), + ImportResult.Signature signature = chooseSignature(result, optionalArgument(1, feature.getArguments())); - String output = chooseOrDefault("outputs", signature.outputs(), - optionalArgument(2, feature.getArguments())); + RankingExpression expression = chooseOutput(signature, + optionalArgument(2, feature.getArguments())); - // Add all constants + // Add all constants (after finding outputs to fail faster when the output is not found) result.constants().forEach((k, v) -> context.rankProfile().addConstantTensor(k, new TensorValue(v))); - return result.expressions().get(output).getRoot(); + return expression.getRoot(); } catch (IllegalArgumentException e) { - throw new IllegalArgumentException("Could not import tensorflow model from " + feature, e); + throw new IllegalArgumentException("Could not use tensorflow model from " + feature, e); } } /** - * Returns the specified, existing map value, or the only map value if no key is specified. + * Returns the specified, existing signature, or the only signature if none is specified. * Throws IllegalArgumentException in all other cases. */ - private <T> T chooseOrDefault(String valueDescription, Map<String, T> map, Optional<String> key) { - if ( ! key.isPresent()) { - if (map.size() == 0) - throw new IllegalArgumentException("No " + valueDescription + " are present"); - if (map.size() > 1) - throw new IllegalArgumentException("Model has multiple " + valueDescription + ", but no " + - valueDescription + " argument is specified"); - return map.values().stream().findFirst().get(); + private ImportResult.Signature chooseSignature(ImportResult importResult, Optional<String> signatureName) { + if ( ! signatureName.isPresent()) { + if (importResult.signatures().size() == 0) + throw new IllegalArgumentException("No signatures are present"); + if (importResult.signatures().size() > 1) + throw new IllegalArgumentException("Model has multiple signatures (" + + Joiner.on(", ").join(importResult.signatures().keySet()) + + "), one must be specified " + + "as a second argument to tensorflow()"); + return importResult.signatures().values().stream().findFirst().get(); } else { - T value = map.get(key.get()); - if (value == null) - throw new IllegalArgumentException("Model does not have the specified " + - valueDescription + " '" + key.get() + "'"); - return value; + ImportResult.Signature signature = importResult.signatures().get(signatureName.get()); + if (signature == null) + throw new IllegalArgumentException("Model does not have the specified signature '" + + signatureName.get() + "'"); + return signature; + } + } + + /** + * Returns the specified, existing output expression, or the only output expression if no output name is specified. + * Throws IllegalArgumentException in all other cases. + */ + private RankingExpression chooseOutput(ImportResult.Signature signature, Optional<String> outputName) { + if ( ! outputName.isPresent()) { + if (signature.outputs().size() == 0) + throw new IllegalArgumentException("No signatures are present"); + if (signature.outputs().size() > 1) + throw new IllegalArgumentException(signature + " has multiple outputs (" + + Joiner.on(", ").join(signature.outputs().keySet()) + + "), one must be specified " + + "as a third argument to tensorflow()"); + return signature.outputExpression(signature.outputs().keySet().stream().findFirst().get()); + } + else { + RankingExpression expression = signature.outputExpression(outputName.get()); + if (expression == null) + throw new IllegalArgumentException("Model does not have the specified output '" + + outputName.get() + "'"); + return expression; } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 5ad85ac872c..8fcd821adfd 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -91,8 +91,8 @@ public class RankingExpressionWithTensorFlowTestCase { fail("Expecting exception"); } catch (IllegalArgumentException expected) { - assertEquals("Rank profile 'my_profile' is invalid: Could not import tensorflow model from tensorflow('" + - modelDirectory + "','serving_defaultz'): Model does not have the specified signatures 'serving_defaultz'", + assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from tensorflow('" + + modelDirectory + "','serving_defaultz'): Model does not have the specified signature 'serving_defaultz'", Exceptions.toMessageString(expected)); } } @@ -110,8 +110,8 @@ public class RankingExpressionWithTensorFlowTestCase { fail("Expecting exception"); } catch (IllegalArgumentException expected) { - assertEquals("Rank profile 'my_profile' is invalid: Could not import tensorflow model from tensorflow('" + - modelDirectory + "','serving_default','x'): Model does not have the specified outputs 'x'", + assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from tensorflow('" + + modelDirectory + "','serving_default','x'): Model does not have the specified output 'x'", Exceptions.toMessageString(expected)); } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java index 947e6d7a5e1..03c0d87fdd0 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java @@ -4,12 +4,9 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; -import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; -import java.util.List; import java.util.Map; -import java.util.stream.Collectors; /** * The result of importing a TensorFlow model into Vespa. @@ -26,12 +23,10 @@ public class ImportResult { private final Map<String, TensorType> arguments = new HashMap<>(); private final Map<String, Tensor> constants = new HashMap<>(); private final Map<String, RankingExpression> expressions = new HashMap<>(); - private final List<String> warnings = new ArrayList<>(); void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); } void constant(String name, Tensor constant) { constants.put(name, constant); } void expression(String name, RankingExpression expression) { expressions.put(name, expression); } - void warn(String warning) { warnings.add(warning); } /** Returns the given signature. If it does not already exist it is added to this. */ Signature signature(String name) { @@ -51,11 +46,6 @@ public class ImportResult { */ public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); } - /** Returns an immutable list, in natural sort order of the warnings generated while importing this */ - public List<String> warnings() { - return warnings.stream().sorted().collect(Collectors.toList()); - } - /** Returns an immutable map of the signatures of this */ public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); } @@ -68,6 +58,7 @@ public class ImportResult { private final String name; private final Map<String, String> inputs = new HashMap<>(); private final Map<String, String> outputs = new HashMap<>(); + private final Map<String, String> skippedOutputs = new HashMap<>(); Signature(String name) { this.name = name; @@ -75,6 +66,7 @@ public class ImportResult { void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); } void output(String name, String expressionName) { outputs.put(name, expressionName); } + void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); } /** Returns the result this is part of */ ImportResult owner() { return ImportResult.this; } @@ -91,6 +83,12 @@ public class ImportResult { /** Returns an immutable list of the expression names of this */ public Map<String, String> outputs() { return Collections.unmodifiableMap(outputs); } + /** + * Returns an immutable list of the outputs of this which could not be imported, + * with a string detailing the reason for each + */ + public Map<String, String> skippedOutputs() { return Collections.unmodifiableMap(skippedOutputs); } + /** Returns owner().expressions().get(outputs.get(outputName)), e.g the expression this output references */ public RankingExpression outputExpression(String outputName) { return owner().expressions().get(outputs.get(outputName)); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index 4a6551adca7..69781fa915c 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -67,8 +67,7 @@ public class TensorFlowImporter { signature.output(outputName, nameOf(output.getValue().getName())); } catch (IllegalArgumentException e) { - result.warn("Skipping output '" + outputName + "' of " + signature + - ": " + Exceptions.toMessageString(e)); + signature.skippedOutput(outputName, Exceptions.toMessageString(e)); } } } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java index 0370fc7fc94..8efeceaa034 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java @@ -30,10 +30,6 @@ public class Mnist_SoftmaxTestCase { SavedModelBundle model = SavedModelBundle.load(modelDir, "serve"); ImportResult result = new TensorFlowImporter().importModel(model); - // Check logged messages - result.warnings().forEach(System.err::println); - assertEquals(0, result.warnings().size()); - // Check constants assertEquals(2, result.constants().size()); @@ -71,6 +67,9 @@ public class Mnist_SoftmaxTestCase { "f(a,b)(a + b))", toNonPrimitiveString(output)); + // ... skipped outputs + assertEquals(0, signature.skippedOutputs().size()); + // Test execution assertEqualResult(model, result, "Variable/read"); assertEqualResult(model, result, "Variable_1/read"); |