aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java')
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java23
1 files changed, 5 insertions, 18 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java
index 39a8e16fad5..0f89a839a26 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java
@@ -1,3 +1,4 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model.ml;
import com.fasterxml.jackson.core.JsonEncoding;
@@ -7,6 +8,7 @@ import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.yahoo.config.application.api.ApplicationFile;
import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.config.model.api.OnnxMemoryStats;
import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.tensor.TensorType;
@@ -44,7 +46,7 @@ public class OnnxModelProbe {
String jsonInput = createJsonInput(app.getFileReference(modelPath).getAbsolutePath(), inputTypes);
var jsonOutput = callVespaAnalyzeOnnxModel(jsonInput);
outputType = outputTypeFromJson(jsonOutput, outputName);
- writeMemoryStats(app, modelPath, MemoryStats.fromJson(jsonOutput));
+ writeMemoryStats(app, modelPath, OnnxMemoryStats.fromJson(jsonOutput));
if ( ! outputType.equals(TensorType.empty)) {
writeProbedOutputType(app, modelPath, contextKey, outputType);
}
@@ -55,16 +57,11 @@ public class OnnxModelProbe {
return outputType;
}
- private static void writeMemoryStats(ApplicationPackage app, Path modelPath, MemoryStats memoryStats) throws IOException {
- String path = app.getFileReference(memoryStatsPath(modelPath)).getAbsolutePath();
+ private static void writeMemoryStats(ApplicationPackage app, Path modelPath, OnnxMemoryStats memoryStats) throws IOException {
+ String path = app.getFileReference(OnnxMemoryStats.memoryStatsFilePath(modelPath)).getAbsolutePath();
IOUtils.writeFile(path, memoryStats.toJson().toPrettyString(), false);
}
- private static Path memoryStatsPath(Path modelPath) {
- var fileName = OnnxModelInfo.asValidIdentifier(modelPath.getRelative()) + ".memory_stats";
- return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(fileName);
- }
-
private static String createContextKey(String onnxName, Map<String, TensorType> inputTypes) {
StringBuilder key = new StringBuilder().append(onnxName).append(":");
inputTypes.entrySet().stream().sorted(Map.Entry.comparingByKey())
@@ -160,14 +157,4 @@ public class OnnxModelProbe {
}
return jsonParser.readTree(output.toString());
}
-
- public record MemoryStats(long vmSize, long vmRss) {
- static MemoryStats fromJson(JsonNode json) {
- return new MemoryStats(json.get("vm_size").asLong(), json.get("vm_rss").asLong());
- }
- JsonNode toJson() {
- return jsonParser.createObjectNode().put("vm_size", vmSize).put("vm_rss", vmRss);
- }
- }
-
}