From fe8095f3fcead730ce8cd4fdb02802a936029720 Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Wed, 23 Feb 2022 15:37:43 +0100 Subject: Don't propagate exception, but return empty tensor type for fallback --- .../src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java | 6 ++++-- .../src/test/java/com/yahoo/vespa/model/ml/OnnxModelProbeTest.java | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) (limited to 'config-model/src') diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java index 7230d0ad59b..2ef81e3f1fa 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java @@ -43,11 +43,13 @@ public class OnnxModelProbe { String jsonInput = createJsonInput(app.getFileReference(modelPath).getAbsolutePath(), inputTypes); String jsonOutput = callVespaAnalyzeOnnxModel(jsonInput); outputType = outputTypeFromJson(jsonOutput, outputName); - writeProbedOutputType(app, modelPath, contextKey, outputType); + if ( ! outputType.equals(TensorType.empty)) { + writeProbedOutputType(app, modelPath, contextKey, outputType); + } } } catch (IOException | InterruptedException e) { - throw new IllegalArgumentException("Unable to probe ONNX model", e); + e.printStackTrace(System.err); } return outputType; diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/OnnxModelProbeTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/OnnxModelProbeTest.java index 55e2f6f018c..6c4c919a229 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/OnnxModelProbeTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/OnnxModelProbeTest.java @@ -30,7 +30,9 @@ public class OnnxModelProbeTest { TensorType expected = TensorType.fromSpec("tensor(d0[1],d1[2],d2[2])"); // Can't test model probing directly as 'vespa-analyze-onnx-model' is not available - + TensorType outputType = OnnxModelProbe.probeModel(app, modelPath, output, inputTypes); + assertEquals(outputType, TensorType.empty); + OnnxModelProbe.writeProbedOutputType(app, modelPath, output, inputTypes, expected); // Test loading from generated info @@ -38,7 +40,7 @@ public class OnnxModelProbeTest { IOUtils.copyDirectory(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); app = FilesApplicationPackage.fromFile(storedAppDir.toFile()); - TensorType outputType = OnnxModelProbe.probeModel(app, modelPath, output, inputTypes); + outputType = OnnxModelProbe.probeModel(app, modelPath, output, inputTypes); assertEquals(outputType, expected); } finally { -- cgit v1.2.3