summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-04-23 11:19:46 +0200
committerGitHub <noreply@github.com>2018-04-23 11:19:46 +0200
commitbc6e85549a8239c71edffc26927b8e989e682f18 (patch)
treeed1683b0c724adca6632b518c31f305cfda8a775
parent156f56fc90835cc4d254e1f26d1835fe007560ef (diff)
parent15eb4d021ecd25dc51479c454ad7d67e3885adb1 (diff)
Merge pull request #5669 from vespa-engine/lesters/tf-model-import-error
Better error message when TensorFlow model import fails
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java11
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java3
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java5
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java3
4 files changed, 21 insertions, 1 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index 1e6645df792..5790a5294eb 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -110,6 +110,17 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
// Find the specified expression
Signature signature = chooseSignature(model, store.arguments().signature());
String output = chooseOutput(signature, store.arguments().output());
+ if (signature.skippedOutputs().containsKey(output)) {
+ String message = "Could not import TensorFlow model output '" + output + "'";
+ if (!signature.skippedOutputs().get(output).isEmpty()) {
+ message += ": " + signature.skippedOutputs().get(output);
+ }
+ if (!signature.importWarnings().isEmpty()) {
+ message += ": " + String.join(", ", signature.importWarnings());
+ }
+ throw new IllegalArgumentException(message);
+ }
+
RankingExpression expression = model.expressions().get(output);
expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
verifyRequiredMacros(expression, model, profile, queryProfiles);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
index 217eafd7446..64c777dbfca 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -366,6 +366,9 @@ public class TensorFlowImporter {
for (String warning : operation.warnings()) {
signature.importWarning(warning);
}
+ for (TensorFlowOperation input : operation.inputs()) {
+ reportWarnings(input, signature);
+ }
}
private static NodeDef getTensorFlowNodeFromGraph(String name, GraphDef graph) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java
index 977b18b9ab3..b665413a6b2 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java
@@ -85,7 +85,10 @@ public class OperationMapper {
case "stopgradient":return new Identity(modelName, node, inputs, port);
case "noop": return new NoOp(modelName, node, inputs, port);
}
- return new NoOp(modelName, node, inputs, port);
+
+ TensorFlowOperation op = new NoOp(modelName, node, inputs, port);
+ op.warning("Operation '" + node.getOp() + "' is currently not implemented");
+ return op;
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java
index a0a3c71145b..3687bba8b85 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java
@@ -132,6 +132,9 @@ public abstract class TensorFlowOperation {
/** Retrieve the list of warnings produced during its lifetime */
public List<String> warnings() { return Collections.unmodifiableList(importWarnings); }
+ /** Set an input warning */
+ public void warning(String warning) { importWarnings.add(warning); }
+
boolean verifyInputs(int expected, Function<TensorFlowOperation, Optional<?>> func) {
if (!controlInputs.stream().map(func).allMatch(Optional::isPresent)) {
return false;