summaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2022-02-23 15:37:43 +0100
committerLester Solbakken <lesters@oath.com>2022-03-23 11:35:30 +0100
commitfe8095f3fcead730ce8cd4fdb02802a936029720 (patch)
treed7065b26508a3874e77bf649ec7aa027eecec62b /config-model
parent60e49842d5b6625b98edd15cfd5c70bbb82425cc (diff)
Don't propagate exception, but return empty tensor type for fallback
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java6
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/ml/OnnxModelProbeTest.java6
2 files changed, 8 insertions, 4 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java
index 7230d0ad59b..2ef81e3f1fa 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java
@@ -43,11 +43,13 @@ public class OnnxModelProbe {
String jsonInput = createJsonInput(app.getFileReference(modelPath).getAbsolutePath(), inputTypes);
String jsonOutput = callVespaAnalyzeOnnxModel(jsonInput);
outputType = outputTypeFromJson(jsonOutput, outputName);
- writeProbedOutputType(app, modelPath, contextKey, outputType);
+ if ( ! outputType.equals(TensorType.empty)) {
+ writeProbedOutputType(app, modelPath, contextKey, outputType);
+ }
}
} catch (IOException | InterruptedException e) {
- throw new IllegalArgumentException("Unable to probe ONNX model", e);
+ e.printStackTrace(System.err);
}
return outputType;
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/OnnxModelProbeTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/OnnxModelProbeTest.java
index 55e2f6f018c..6c4c919a229 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/OnnxModelProbeTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/OnnxModelProbeTest.java
@@ -30,7 +30,9 @@ public class OnnxModelProbeTest {
TensorType expected = TensorType.fromSpec("tensor<float>(d0[1],d1[2],d2[2])");
// Can't test model probing directly as 'vespa-analyze-onnx-model' is not available
-
+ TensorType outputType = OnnxModelProbe.probeModel(app, modelPath, output, inputTypes);
+ assertEquals(outputType, TensorType.empty);
+
OnnxModelProbe.writeProbedOutputType(app, modelPath, output, inputTypes, expected);
// Test loading from generated info
@@ -38,7 +40,7 @@ public class OnnxModelProbeTest {
IOUtils.copyDirectory(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
app = FilesApplicationPackage.fromFile(storedAppDir.toFile());
- TensorType outputType = OnnxModelProbe.probeModel(app, modelPath, output, inputTypes);
+ outputType = OnnxModelProbe.probeModel(app, modelPath, output, inputTypes);
assertEquals(outputType, expected);
} finally {