diff options
author | Lester Solbakken <lesters@oath.com> | 2022-02-23 15:37:43 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2022-03-23 11:35:30 +0100 |
commit | fe8095f3fcead730ce8cd4fdb02802a936029720 (patch) | |
tree | d7065b26508a3874e77bf649ec7aa027eecec62b /config-model | |
parent | 60e49842d5b6625b98edd15cfd5c70bbb82425cc (diff) |
Don't propagate exception, but return empty tensor type for fallback
Diffstat (limited to 'config-model')
-rw-r--r-- | config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java | 6 | ||||
-rw-r--r-- | config-model/src/test/java/com/yahoo/vespa/model/ml/OnnxModelProbeTest.java | 6 |
2 files changed, 8 insertions, 4 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java index 7230d0ad59b..2ef81e3f1fa 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java @@ -43,11 +43,13 @@ public class OnnxModelProbe { String jsonInput = createJsonInput(app.getFileReference(modelPath).getAbsolutePath(), inputTypes); String jsonOutput = callVespaAnalyzeOnnxModel(jsonInput); outputType = outputTypeFromJson(jsonOutput, outputName); - writeProbedOutputType(app, modelPath, contextKey, outputType); + if ( ! outputType.equals(TensorType.empty)) { + writeProbedOutputType(app, modelPath, contextKey, outputType); + } } } catch (IOException | InterruptedException e) { - throw new IllegalArgumentException("Unable to probe ONNX model", e); + e.printStackTrace(System.err); } return outputType; diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/OnnxModelProbeTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/OnnxModelProbeTest.java index 55e2f6f018c..6c4c919a229 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/OnnxModelProbeTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/OnnxModelProbeTest.java @@ -30,7 +30,9 @@ public class OnnxModelProbeTest { TensorType expected = TensorType.fromSpec("tensor<float>(d0[1],d1[2],d2[2])"); // Can't test model probing directly as 'vespa-analyze-onnx-model' is not available - + TensorType outputType = OnnxModelProbe.probeModel(app, modelPath, output, inputTypes); + assertEquals(outputType, TensorType.empty); + OnnxModelProbe.writeProbedOutputType(app, modelPath, output, inputTypes, expected); // Test loading from generated info @@ -38,7 +40,7 @@ public class OnnxModelProbeTest { IOUtils.copyDirectory(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); app = FilesApplicationPackage.fromFile(storedAppDir.toFile()); - TensorType outputType = OnnxModelProbe.probeModel(app, modelPath, output, inputTypes); + outputType = OnnxModelProbe.probeModel(app, modelPath, output, inputTypes); assertEquals(outputType, expected); } finally { |