diff options
author | Bjørn Christian Seime <bjorn.christian@seime.no> | 2023-09-25 14:53:50 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-25 14:53:50 +0200 |
commit | 27472f94770a6644d44b765cedf802d8bb38ac03 (patch) | |
tree | d219d2dc12f34f5e93ff8b3c03c1fc6ac8abf238 /config-model | |
parent | 7facdd6177063f772c497000b9c12e4653a2db83 (diff) | |
parent | 2a537e9ce9223110ca2bbedd7e88139c24524049 (diff) |
Merge pull request #28645 from vespa-engine/bjorncs/analyze-model
Bjorncs/analyze model
Diffstat (limited to 'config-model')
4 files changed, 63 insertions, 12 deletions
```diff
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/DefaultOnnxModelCost.java b/config-model/src/main/java/com/yahoo/vespa/model/DefaultOnnxModelCost.java
index 76733872882..9794cfe4ad7 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/DefaultOnnxModelCost.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/DefaultOnnxModelCost.java
@@ -4,8 +4,10 @@ package com.yahoo.vespa.model;
 import com.yahoo.config.ModelReference;
 import com.yahoo.config.application.api.ApplicationFile;
+import com.yahoo.config.application.api.ApplicationPackage;
 import com.yahoo.config.application.api.DeployLogger;
 import com.yahoo.config.model.api.OnnxModelCost;
+import com.yahoo.vespa.model.ml.OnnxModelProbe;
 import java.io.IOException;
 import java.net.URI;
@@ -29,16 +31,18 @@ import static com.yahoo.yolean.Exceptions.uncheck;
 public class DefaultOnnxModelCost implements OnnxModelCost {
     @Override
-    public Calculator newCalculator(DeployLogger logger) {
-        return new CalculatorImpl(logger);
+    public Calculator newCalculator(ApplicationPackage appPkg, DeployLogger logger) {
+        return new CalculatorImpl(appPkg, logger);
     }
     private static class CalculatorImpl implements Calculator {
         private final DeployLogger log;
+        private final ApplicationPackage appPkg;
         private final ConcurrentMap<String, Long> modelCost = new ConcurrentHashMap<>();
-        private CalculatorImpl(DeployLogger log) {
+        private CalculatorImpl(ApplicationPackage appPkg, DeployLogger log) {
+            this.appPkg = appPkg;
             this.log = log;
         }
@@ -52,7 +56,17 @@ public class DefaultOnnxModelCost implements OnnxModelCost {
             String path = f.getPath().getRelative();
             if (alreadyAnalyzed(path)) return;
             log.log(Level.FINE, () -> "Register model '%s'".formatted(path));
-            deductJvmHeapSizeWithModelCost(f.exists() ? f.getSize() : 0, path);
+            if (f.exists()) {
+                var memoryStats = OnnxModelProbe.probeMemoryStats(appPkg, f.getPath()).orElse(null);
+                if (memoryStats != null) {
+                    log.log(Level.FINE, () -> "Register model '%s' with memory stats: %s".formatted(path, memoryStats));
+                    deductJvmHeapSizeWithModelCost(f.getSize(), memoryStats, path);
+                } else {
+                    deductJvmHeapSizeWithModelCost(f.getSize(), path);
+                }
+            } else {
+                deductJvmHeapSizeWithModelCost(0, path);
+            }
         }
         @Override
@@ -92,6 +106,13 @@ public class DefaultOnnxModelCost implements OnnxModelCost {
             modelCost.put(source, estimatedCost);
         }
+        private void deductJvmHeapSizeWithModelCost(long size, OnnxModelProbe.MemoryStats stats, String source) {
+            long estimatedCost = (long)(1.1D * stats.vmSize());
+            log.log(Level.FINE, () ->
+                    "Estimated %s footprint for model of size %s ('%s')".formatted(mb(estimatedCost), mb(size), source));
+            modelCost.put(source, estimatedCost);
+        }
+
         private boolean alreadyAnalyzed(String source) { return modelCost.containsKey(source); }
         private static String mb(long bytes) { return "%dMB".formatted(bytes / (1024*1024)); }
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java
index 2227831a8a0..be1b952a834 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java
@@ -136,7 +136,8 @@ public final class ApplicationContainerCluster extends ContainerCluster<Applicat
         heapSizePercentageOfAvailableMemory = deployState.featureFlags().heapSizePercentage() > 0 ? Math.min(99, deployState.featureFlags().heapSizePercentage()) : defaultHeapSizePercentageOfAvailableMemory;
-        onnxModelCost = deployState.onnxModelCost().newCalculator(deployState.getDeployLogger());
+        onnxModelCost = deployState.onnxModelCost().newCalculator(
+                deployState.getApplicationPackage(), deployState.getDeployLogger());
         logger = deployState.getDeployLogger();
     }
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java
index 7c86267c1b6..38dda3e29ff 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java
@@ -18,6 +18,9 @@ import java.io.InputStream;
 import java.io.OutputStream;
 import java.nio.charset.StandardCharsets;
 import java.util.Map;
+import java.util.Optional;
+
+import static com.yahoo.yolean.Exceptions.uncheck;
 /**
  * Defers to 'vespa-analyze-onnx-model' to determine the output type given
@@ -29,6 +32,7 @@ import java.util.Map;
 public class OnnxModelProbe {
     private static final String binary = "vespa-analyze-onnx-model";
+    private static final ObjectMapper jsonParser = new ObjectMapper();
     static TensorType probeModel(ApplicationPackage app, Path modelPath, String outputName, Map<String, TensorType> inputTypes) {
         TensorType outputType = TensorType.empty;
@@ -41,8 +45,9 @@ public class OnnxModelProbe {
             // Otherwise, run vespa-analyze-onnx-model if the model is available
             if (outputType.equals(TensorType.empty) && app.getFile(modelPath).exists()) {
                 String jsonInput = createJsonInput(app.getFileReference(modelPath).getAbsolutePath(), inputTypes);
-                String jsonOutput = callVespaAnalyzeOnnxModel(jsonInput);
+                var jsonOutput = callVespaAnalyzeOnnxModel(jsonInput);
                 outputType = outputTypeFromJson(jsonOutput, outputName);
+                writeMemoryStats(app, modelPath, MemoryStats.fromJson(jsonOutput));
                 if ( ! outputType.equals(TensorType.empty)) {
                     writeProbedOutputType(app, modelPath, contextKey, outputType);
                 }
@@ -53,6 +58,22 @@ public class OnnxModelProbe {
         return outputType;
     }
+    public static Optional<MemoryStats> probeMemoryStats(ApplicationPackage app, Path modelPath) {
+        return Optional.of(app.getFile(memoryStatsPath(modelPath)))
+                .filter(ApplicationFile::exists)
+                .map(file -> MemoryStats.fromJson(uncheck(() -> jsonParser.readTree(file.createReader()))));
+    }
+
+    private static void writeMemoryStats(ApplicationPackage app, Path modelPath, MemoryStats memoryStats) throws IOException {
+        String path = app.getFileReference(memoryStatsPath(modelPath)).getAbsolutePath();
+        IOUtils.writeFile(path, memoryStats.toJson().toPrettyString(), false);
+    }
+
+    private static Path memoryStatsPath(Path modelPath) {
+        var fileName = OnnxModelInfo.asValidIdentifier(modelPath.getRelative()) + ".memory_stats";
+        return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(fileName);
+    }
+
     private static String createContextKey(String onnxName, Map<String, TensorType> inputTypes) {
         StringBuilder key = new StringBuilder().append(onnxName).append(":");
         inputTypes.entrySet().stream().sorted(Map.Entry.comparingByKey())
@@ -95,9 +116,7 @@ public class OnnxModelProbe {
         return TensorType.empty;
     }
-    private static TensorType outputTypeFromJson(String json, String outputName) throws IOException {
-        ObjectMapper m = new ObjectMapper();
-        JsonNode root = m.readTree(json);
+    private static TensorType outputTypeFromJson(JsonNode root, String outputName) throws IOException {
         if ( ! root.isObject() || ! root.has("outputs")) {
             return TensorType.empty;
         }
@@ -123,7 +142,7 @@ public class OnnxModelProbe {
         return out.toString();
     }
-    private static String callVespaAnalyzeOnnxModel(String jsonInput) throws IOException, InterruptedException {
+    private static JsonNode callVespaAnalyzeOnnxModel(String jsonInput) throws IOException, InterruptedException {
         StringBuilder output = new StringBuilder();
         ProcessBuilder processBuilder = new ProcessBuilder(binary, "--probe-types");
@@ -148,7 +167,16 @@ public class OnnxModelProbe {
             throw new IllegalArgumentException("Error from '" + binary + "'. Return code: " + returnCode + ". " + "Output: '" + output + "'");
         }
-        return output.toString();
+        return jsonParser.readTree(output.toString());
+    }
+
+    public record MemoryStats(long vmSize, long vmRss) {
+        static MemoryStats fromJson(JsonNode json) {
+            return new MemoryStats(json.get("vm_size").asLong(), json.get("vm_rss").asLong());
+        }
+        JsonNode toJson() {
+            return jsonParser.createObjectNode().put("vm_size", vmSize).put("vm_rss", vmRss);
+        }
     }
 }
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java
index 245887a5d03..447614b8396 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java
@@ -4,6 +4,7 @@ package com.yahoo.vespa.model.application.validation;
 import com.yahoo.config.ModelReference;
 import com.yahoo.config.application.api.ApplicationFile;
+import com.yahoo.config.application.api.ApplicationPackage;
 import com.yahoo.config.application.api.DeployLogger;
 import com.yahoo.config.model.NullConfigModelRegistry;
 import com.yahoo.config.model.api.OnnxModelCost;
@@ -112,7 +113,7 @@ class JvmHeapSizeValidatorTest {
         ModelCostDummy(long modelCost) { this.modelCost = modelCost; }
-        @Override public Calculator newCalculator(DeployLogger logger) { return this; }
+        @Override public Calculator newCalculator(ApplicationPackage appPkg, DeployLogger logger) { return this; }
         @Override public long aggregatedModelCostInBytes() { return totalCost.get(); }
         @Override public void registerModel(ApplicationFile path) {}
```