summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorLester Solbakken <lesters@users.noreply.github.com>2019-12-18 10:28:11 +0100
committerGitHub <noreply@github.com>2019-12-18 10:28:11 +0100
commit2041886aa197dff3a68abe5e821401f37adaa6f4 (patch)
treebd104f117856a5d7aa429be6068c5a0893aed0be /model-integration
parent2c5f5037758f1efed35a4b6a56e301a4bc1379d7 (diff)
parenta4fc968c0942b5813b2cbe140cbb688631629b9a (diff)
Merge pull request #11567 from vespa-engine/lesters/better-error-message-for-unsupported-onnx-operations
Add better error messages for unsupported ONNX operations
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java21
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/Tf2OnnxImportTestCase.java2
2 files changed, 22 insertions, 1 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
index 55f5d979ea8..d42338deaf8 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
@@ -27,8 +27,10 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.functions.ScalarFunctions;
import onnx.Onnx;
+import java.util.HashSet;
import java.util.List;
import java.util.Optional;
+import java.util.Set;
import java.util.stream.Collectors;
/**
@@ -123,6 +125,7 @@ class GraphImporter {
IntermediateGraph intermediateGraph = new IntermediateGraph(modelName);
importOperations(onnxGraph, intermediateGraph);
+ verifyNoWarnings(intermediateGraph);
verifyOutputTypes(onnxGraph, intermediateGraph);
return intermediateGraph;
@@ -234,6 +237,16 @@ class GraphImporter {
.collect(Collectors.toList());
}
+ private static void verifyNoWarnings(IntermediateGraph intermediateGraph) {
+ for (java.util.Map.Entry<String, String> output : intermediateGraph.outputs(intermediateGraph.defaultSignature()).entrySet()) {
+ IntermediateOperation operation = intermediateGraph.get(output.getValue());
+ Set<String> warnings = getWarnings(operation);
+ if (warnings.size() > 0) {
+ throw new IllegalArgumentException("Could not import " + intermediateGraph.name() + ": " + String.join("\n", warnings));
+ }
+ }
+ }
+
private static void verifyOutputTypes(Onnx.GraphProto onnxGraph, IntermediateGraph intermediateGraph) {
for (java.util.Map.Entry<String, String> output : intermediateGraph.outputs(intermediateGraph.defaultSignature()).entrySet()) {
IntermediateOperation operation = intermediateGraph.get(output.getValue());
@@ -284,4 +297,12 @@ class GraphImporter {
"Either no explicit name given or no single output name.");
}
+ private static Set<String> getWarnings(IntermediateOperation op) {
+ Set<String> warnings = new HashSet<>(op.warnings());
+ for (IntermediateOperation input : op.inputs()) {
+ warnings.addAll(getWarnings(input));
+ }
+ return warnings;
+ }
+
}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/Tf2OnnxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/Tf2OnnxImportTestCase.java
index 4250fee4d20..c7245fe53e8 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/Tf2OnnxImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/Tf2OnnxImportTestCase.java
@@ -39,7 +39,7 @@ public class Tf2OnnxImportTestCase extends TestableModel {
@Ignore
public void testOnnxConversionAndImport() {
Report report = new Report();
- for (int i = 11; i < 12; ++i) {
+ for (int i = 1; i < 12; ++i) {
testModelsWithOpset(report, i);
}
System.out.println(report);