From a4fc968c0942b5813b2cbe140cbb688631629b9a Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Mon, 16 Dec 2019 09:50:08 +0100 Subject: Add better error messages for unsupported ONNX operations --- .../importer/onnx/GraphImporter.java | 21 +++++++++++++++++++++ .../importer/onnx/Tf2OnnxImportTestCase.java | 2 +- 2 files changed, 22 insertions(+), 1 deletion(-) (limited to 'model-integration') diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java index 55f5d979ea8..d42338deaf8 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java @@ -27,8 +27,10 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.functions.ScalarFunctions; import onnx.Onnx; +import java.util.HashSet; import java.util.List; import java.util.Optional; +import java.util.Set; import java.util.stream.Collectors; /** @@ -123,6 +125,7 @@ class GraphImporter { IntermediateGraph intermediateGraph = new IntermediateGraph(modelName); importOperations(onnxGraph, intermediateGraph); + verifyNoWarnings(intermediateGraph); verifyOutputTypes(onnxGraph, intermediateGraph); return intermediateGraph; @@ -234,6 +237,16 @@ class GraphImporter { .collect(Collectors.toList()); } + private static void verifyNoWarnings(IntermediateGraph intermediateGraph) { + for (java.util.Map.Entry output : intermediateGraph.outputs(intermediateGraph.defaultSignature()).entrySet()) { + IntermediateOperation operation = intermediateGraph.get(output.getValue()); + Set warnings = getWarnings(operation); + if (warnings.size() > 0) { + throw new IllegalArgumentException("Could not import " + intermediateGraph.name() + ": " + String.join("\n", warnings)); + } + } + } + private static void verifyOutputTypes(Onnx.GraphProto onnxGraph, IntermediateGraph intermediateGraph) { for (java.util.Map.Entry output : intermediateGraph.outputs(intermediateGraph.defaultSignature()).entrySet()) { IntermediateOperation operation = intermediateGraph.get(output.getValue()); @@ -284,4 +297,12 @@ class GraphImporter { "Either no explicit name given or no single output name."); } + private static Set getWarnings(IntermediateOperation op) { + Set warnings = new HashSet<>(op.warnings()); + for (IntermediateOperation input : op.inputs()) { + warnings.addAll(getWarnings(input)); + } + return warnings; + } + } diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/Tf2OnnxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/Tf2OnnxImportTestCase.java index 4250fee4d20..c7245fe53e8 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/Tf2OnnxImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/Tf2OnnxImportTestCase.java @@ -39,7 +39,7 @@ public class Tf2OnnxImportTestCase extends TestableModel { @Ignore public void testOnnxConversionAndImport() { Report report = new Report(); - for (int i = 11; i < 12; ++i) { + for (int i = 1; i < 12; ++i) { testModelsWithOpset(report, i); } System.out.println(report); -- cgit v1.2.3