summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2019-12-16 09:50:08 +0100
committerLester Solbakken <lesters@oath.com>2019-12-16 09:50:08 +0100
commita4fc968c0942b5813b2cbe140cbb688631629b9a (patch)
tree4be0d09fb697325330deb18d7eadb10c8f6cc71e /model-integration
parentd03b908f4105413a1833b19df6959bf644702233 (diff)
Add better error messages for unsupported ONNX operations
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java21
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/Tf2OnnxImportTestCase.java2
2 files changed, 22 insertions, 1 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
index 55f5d979ea8..d42338deaf8 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
@@ -27,8 +27,10 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.functions.ScalarFunctions;
import onnx.Onnx;
+import java.util.HashSet;
import java.util.List;
import java.util.Optional;
+import java.util.Set;
import java.util.stream.Collectors;
/**
@@ -123,6 +125,7 @@ class GraphImporter {
IntermediateGraph intermediateGraph = new IntermediateGraph(modelName);
importOperations(onnxGraph, intermediateGraph);
+ verifyNoWarnings(intermediateGraph);
verifyOutputTypes(onnxGraph, intermediateGraph);
return intermediateGraph;
@@ -234,6 +237,16 @@ class GraphImporter {
.collect(Collectors.toList());
}
+ private static void verifyNoWarnings(IntermediateGraph intermediateGraph) {
+ for (java.util.Map.Entry<String, String> output : intermediateGraph.outputs(intermediateGraph.defaultSignature()).entrySet()) {
+ IntermediateOperation operation = intermediateGraph.get(output.getValue());
+ Set<String> warnings = getWarnings(operation);
+ if (warnings.size() > 0) {
+ throw new IllegalArgumentException("Could not import " + intermediateGraph.name() + ": " + String.join("\n", warnings));
+ }
+ }
+ }
+
private static void verifyOutputTypes(Onnx.GraphProto onnxGraph, IntermediateGraph intermediateGraph) {
for (java.util.Map.Entry<String, String> output : intermediateGraph.outputs(intermediateGraph.defaultSignature()).entrySet()) {
IntermediateOperation operation = intermediateGraph.get(output.getValue());
@@ -284,4 +297,12 @@ class GraphImporter {
"Either no explicit name given or no single output name.");
}
+ private static Set<String> getWarnings(IntermediateOperation op) {
+ Set<String> warnings = new HashSet<>(op.warnings());
+ for (IntermediateOperation input : op.inputs()) {
+ warnings.addAll(getWarnings(input));
+ }
+ return warnings;
+ }
+
}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/Tf2OnnxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/Tf2OnnxImportTestCase.java
index 4250fee4d20..c7245fe53e8 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/Tf2OnnxImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/Tf2OnnxImportTestCase.java
@@ -39,7 +39,7 @@ public class Tf2OnnxImportTestCase extends TestableModel {
@Ignore
public void testOnnxConversionAndImport() {
Report report = new Report();
- for (int i = 11; i < 12; ++i) {
+ for (int i = 1; i < 12; ++i) {
testModelsWithOpset(report, i);
}
System.out.println(report);