diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-04-23 11:19:46 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-04-23 11:19:46 +0200 |
commit | bc6e85549a8239c71edffc26927b8e989e682f18 (patch) | |
tree | ed1683b0c724adca6632b518c31f305cfda8a775 | |
parent | 156f56fc90835cc4d254e1f26d1835fe007560ef (diff) | |
parent | 15eb4d021ecd25dc51479c454ad7d67e3885adb1 (diff) |
Merge pull request #5669 from vespa-engine/lesters/tf-model-import-error
Better error message when TensorFlow model import fails
4 files changed, 21 insertions, 1 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index 1e6645df792..5790a5294eb 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -110,6 +110,17 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil // Find the specified expression Signature signature = chooseSignature(model, store.arguments().signature()); String output = chooseOutput(signature, store.arguments().output()); + if (signature.skippedOutputs().containsKey(output)) { + String message = "Could not import TensorFlow model output '" + output + "'"; + if (!signature.skippedOutputs().get(output).isEmpty()) { + message += ": " + signature.skippedOutputs().get(output); + } + if (!signature.importWarnings().isEmpty()) { + message += ": " + String.join(", ", signature.importWarnings()); + } + throw new IllegalArgumentException(message); + } + RankingExpression expression = model.expressions().get(output); expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); verifyRequiredMacros(expression, model, profile, queryProfiles); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index 217eafd7446..64c777dbfca 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -366,6 +366,9 @@ public class TensorFlowImporter { for (String warning : operation.warnings()) { signature.importWarning(warning); } + for (TensorFlowOperation input : operation.inputs()) { + reportWarnings(input, signature); + } } private static NodeDef getTensorFlowNodeFromGraph(String name, GraphDef graph) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java index 977b18b9ab3..b665413a6b2 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java @@ -85,7 +85,10 @@ public class OperationMapper { case "stopgradient":return new Identity(modelName, node, inputs, port); case "noop": return new NoOp(modelName, node, inputs, port); } - return new NoOp(modelName, node, inputs, port); + + TensorFlowOperation op = new NoOp(modelName, node, inputs, port); + op.warning("Operation '" + node.getOp() + "' is currently not implemented"); + return op; } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java index a0a3c71145b..3687bba8b85 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java @@ -132,6 +132,9 @@ public abstract class TensorFlowOperation { /** Retrieve the list of warnings produced during its lifetime */ public List<String> warnings() { return Collections.unmodifiableList(importWarnings); } + /** Set an input warning */ + public void warning(String warning) { importWarnings.add(warning); } + boolean verifyInputs(int expected, Function<TensorFlowOperation, Optional<?>> func) { if (!controlInputs.stream().map(func).allMatch(Optional::isPresent)) { return false; |