diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/vespa/model/ml')
5 files changed, 9 insertions, 22 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index ea7d1620fd9..f007065a0c2 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.ml; import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction; diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/FeatureArguments.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/FeatureArguments.java index b52fd060a1c..f1c1587552e 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/FeatureArguments.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/FeatureArguments.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.ml; import com.yahoo.path.Path; diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ModelName.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ModelName.java index 1e5e95f06d5..79b1fe46729 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ModelName.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ModelName.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.ml; import com.yahoo.path.Path; diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java index 8edd446b209..c622b4d58b4 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.ml; import com.fasterxml.jackson.core.JsonEncoding; diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java index 39a8e16fad5..0f89a839a26 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java @@ -1,3 +1,4 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.ml; import com.fasterxml.jackson.core.JsonEncoding; @@ -7,6 +8,7 @@ import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.yahoo.config.application.api.ApplicationFile; import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.model.api.OnnxMemoryStats; import com.yahoo.io.IOUtils; import com.yahoo.path.Path; import com.yahoo.tensor.TensorType; @@ -44,7 +46,7 @@ public class OnnxModelProbe { String jsonInput = createJsonInput(app.getFileReference(modelPath).getAbsolutePath(), inputTypes); var jsonOutput = callVespaAnalyzeOnnxModel(jsonInput); outputType = outputTypeFromJson(jsonOutput, outputName); - writeMemoryStats(app, modelPath, MemoryStats.fromJson(jsonOutput)); + writeMemoryStats(app, modelPath, OnnxMemoryStats.fromJson(jsonOutput)); if ( ! outputType.equals(TensorType.empty)) { writeProbedOutputType(app, modelPath, contextKey, outputType); } @@ -55,16 +57,11 @@ public class OnnxModelProbe { return outputType; } - private static void writeMemoryStats(ApplicationPackage app, Path modelPath, MemoryStats memoryStats) throws IOException { - String path = app.getFileReference(memoryStatsPath(modelPath)).getAbsolutePath(); + private static void writeMemoryStats(ApplicationPackage app, Path modelPath, OnnxMemoryStats memoryStats) throws IOException { + String path = app.getFileReference(OnnxMemoryStats.memoryStatsFilePath(modelPath)).getAbsolutePath(); IOUtils.writeFile(path, memoryStats.toJson().toPrettyString(), false); } - private static Path memoryStatsPath(Path modelPath) { - var fileName = OnnxModelInfo.asValidIdentifier(modelPath.getRelative()) + ".memory_stats"; - return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(fileName); - } - private static String createContextKey(String onnxName, Map<String, TensorType> inputTypes) { StringBuilder key = new StringBuilder().append(onnxName).append(":"); inputTypes.entrySet().stream().sorted(Map.Entry.comparingByKey()) @@ -160,14 +157,4 @@ public class OnnxModelProbe { } return jsonParser.readTree(output.toString()); } - - public record MemoryStats(long vmSize, long vmRss) { - static MemoryStats fromJson(JsonNode json) { - return new MemoryStats(json.get("vm_size").asLong(), json.get("vm_rss").asLong()); - } - JsonNode toJson() { - return jsonParser.createObjectNode().put("vm_size", vmSize).put("vm_rss", vmRss); - } - } - } |