summaryrefslogtreecommitdiffstats
path: root/config-model/src
diff options
context:
space:
mode:
authorBjørn Christian Seime <bjorncs@yahooinc.com>2023-09-25 11:09:58 +0200
committerBjørn Christian Seime <bjorncs@yahooinc.com>2023-09-25 11:09:58 +0200
commit99e6c5335971b7c540313c00d72baf2cc1e2ec3e (patch)
tree9a04fc1b696cde3b31d17539f7057528161bc572 /config-model/src
parent17f452d78e97fd6a90b220dc417de29029a026cd (diff)
Return `JsonNode`
Diffstat (limited to 'config-model/src')
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java11
1 files changed, 5 insertions, 6 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java
index 7c86267c1b6..084f5dac94d 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java
@@ -29,6 +29,7 @@ import java.util.Map;
public class OnnxModelProbe {
private static final String binary = "vespa-analyze-onnx-model";
+ private static final ObjectMapper jsonParser = new ObjectMapper();
static TensorType probeModel(ApplicationPackage app, Path modelPath, String outputName, Map<String, TensorType> inputTypes) {
TensorType outputType = TensorType.empty;
@@ -41,7 +42,7 @@ public class OnnxModelProbe {
// Otherwise, run vespa-analyze-onnx-model if the model is available
if (outputType.equals(TensorType.empty) && app.getFile(modelPath).exists()) {
String jsonInput = createJsonInput(app.getFileReference(modelPath).getAbsolutePath(), inputTypes);
- String jsonOutput = callVespaAnalyzeOnnxModel(jsonInput);
+ var jsonOutput = callVespaAnalyzeOnnxModel(jsonInput);
outputType = outputTypeFromJson(jsonOutput, outputName);
if ( ! outputType.equals(TensorType.empty)) {
writeProbedOutputType(app, modelPath, contextKey, outputType);
@@ -95,9 +96,7 @@ public class OnnxModelProbe {
return TensorType.empty;
}
- private static TensorType outputTypeFromJson(String json, String outputName) throws IOException {
- ObjectMapper m = new ObjectMapper();
- JsonNode root = m.readTree(json);
+ private static TensorType outputTypeFromJson(JsonNode root, String outputName) throws IOException {
if ( ! root.isObject() || ! root.has("outputs")) {
return TensorType.empty;
}
@@ -123,7 +122,7 @@ public class OnnxModelProbe {
return out.toString();
}
- private static String callVespaAnalyzeOnnxModel(String jsonInput) throws IOException, InterruptedException {
+ private static JsonNode callVespaAnalyzeOnnxModel(String jsonInput) throws IOException, InterruptedException {
StringBuilder output = new StringBuilder();
ProcessBuilder processBuilder = new ProcessBuilder(binary, "--probe-types");
@@ -148,7 +147,7 @@ public class OnnxModelProbe {
throw new IllegalArgumentException("Error from '" + binary + "'. Return code: " + returnCode + ". " +
"Output: '" + output + "'");
}
- return output.toString();
+ return jsonParser.readTree(output.toString());
}
}