summary refs log tree commit diff stats
path: root/config-model
diff options
context:
space:
mode:
author Bjørn Christian Seime <bjorn.christian@seime.no> 2023-09-25 14:53:50 +0200
committer GitHub <noreply@github.com> 2023-09-25 14:53:50 +0200
commit 27472f94770a6644d44b765cedf802d8bb38ac03 (patch)
tree d219d2dc12f34f5e93ff8b3c03c1fc6ac8abf238 /config-model
parent 7facdd6177063f772c497000b9c12e4653a2db83 (diff)
parent 2a537e9ce9223110ca2bbedd7e88139c24524049 (diff)
Merge pull request #28645 from vespa-engine/bjorncs/analyze-model
Bjorncs/analyze model
Diffstat (limited to 'config-model')
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/DefaultOnnxModelCost.java 29
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java 3
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java 40
-rw-r--r-- config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java 3
4 files changed, 63 insertions, 12 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/DefaultOnnxModelCost.java b/config-model/src/main/java/com/yahoo/vespa/model/DefaultOnnxModelCost.java
index 76733872882..9794cfe4ad7 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/DefaultOnnxModelCost.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/DefaultOnnxModelCost.java
@@ -4,8 +4,10 @@ package com.yahoo.vespa.model;
import com.yahoo.config.ModelReference;
import com.yahoo.config.application.api.ApplicationFile;
+import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.config.application.api.DeployLogger;
import com.yahoo.config.model.api.OnnxModelCost;
+import com.yahoo.vespa.model.ml.OnnxModelProbe;
import java.io.IOException;
import java.net.URI;
@@ -29,16 +31,18 @@ import static com.yahoo.yolean.Exceptions.uncheck;
public class DefaultOnnxModelCost implements OnnxModelCost {
@Override
- public Calculator newCalculator(DeployLogger logger) {
- return new CalculatorImpl(logger);
+ public Calculator newCalculator(ApplicationPackage appPkg, DeployLogger logger) {
+ return new CalculatorImpl(appPkg, logger);
}
private static class CalculatorImpl implements Calculator {
private final DeployLogger log;
+ private final ApplicationPackage appPkg;
private final ConcurrentMap<String, Long> modelCost = new ConcurrentHashMap<>();
- private CalculatorImpl(DeployLogger log) {
+ private CalculatorImpl(ApplicationPackage appPkg, DeployLogger log) {
+ this.appPkg = appPkg;
this.log = log;
}
@@ -52,7 +56,17 @@ public class DefaultOnnxModelCost implements OnnxModelCost {
String path = f.getPath().getRelative();
if (alreadyAnalyzed(path)) return;
log.log(Level.FINE, () -> "Register model '%s'".formatted(path));
- deductJvmHeapSizeWithModelCost(f.exists() ? f.getSize() : 0, path);
+ if (f.exists()) {
+ var memoryStats = OnnxModelProbe.probeMemoryStats(appPkg, f.getPath()).orElse(null);
+ if (memoryStats != null) {
+ log.log(Level.FINE, () -> "Register model '%s' with memory stats: %s".formatted(path, memoryStats));
+ deductJvmHeapSizeWithModelCost(f.getSize(), memoryStats, path);
+ } else {
+ deductJvmHeapSizeWithModelCost(f.getSize(), path);
+ }
+ } else {
+ deductJvmHeapSizeWithModelCost(0, path);
+ }
}
@Override
@@ -92,6 +106,13 @@ public class DefaultOnnxModelCost implements OnnxModelCost {
modelCost.put(source, estimatedCost);
}
+ private void deductJvmHeapSizeWithModelCost(long size, OnnxModelProbe.MemoryStats stats, String source) {
+ long estimatedCost = (long)(1.1D * stats.vmSize());
+ log.log(Level.FINE, () ->
+ "Estimated %s footprint for model of size %s ('%s')".formatted(mb(estimatedCost), mb(size), source));
+ modelCost.put(source, estimatedCost);
+ }
+
private boolean alreadyAnalyzed(String source) { return modelCost.containsKey(source); }
private static String mb(long bytes) { return "%dMB".formatted(bytes / (1024*1024)); }
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java
index 2227831a8a0..be1b952a834 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java
@@ -136,7 +136,8 @@ public final class ApplicationContainerCluster extends ContainerCluster<Applicat
heapSizePercentageOfAvailableMemory = deployState.featureFlags().heapSizePercentage() > 0
? Math.min(99, deployState.featureFlags().heapSizePercentage())
: defaultHeapSizePercentageOfAvailableMemory;
- onnxModelCost = deployState.onnxModelCost().newCalculator(deployState.getDeployLogger());
+ onnxModelCost = deployState.onnxModelCost().newCalculator(
+ deployState.getApplicationPackage(), deployState.getDeployLogger());
logger = deployState.getDeployLogger();
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java
index 7c86267c1b6..38dda3e29ff 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java
@@ -18,6 +18,9 @@ import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.util.Map;
+import java.util.Optional;
+
+import static com.yahoo.yolean.Exceptions.uncheck;
/**
* Defers to 'vespa-analyze-onnx-model' to determine the output type given
@@ -29,6 +32,7 @@ import java.util.Map;
public class OnnxModelProbe {
private static final String binary = "vespa-analyze-onnx-model";
+ private static final ObjectMapper jsonParser = new ObjectMapper();
static TensorType probeModel(ApplicationPackage app, Path modelPath, String outputName, Map<String, TensorType> inputTypes) {
TensorType outputType = TensorType.empty;
@@ -41,8 +45,9 @@ public class OnnxModelProbe {
// Otherwise, run vespa-analyze-onnx-model if the model is available
if (outputType.equals(TensorType.empty) && app.getFile(modelPath).exists()) {
String jsonInput = createJsonInput(app.getFileReference(modelPath).getAbsolutePath(), inputTypes);
- String jsonOutput = callVespaAnalyzeOnnxModel(jsonInput);
+ var jsonOutput = callVespaAnalyzeOnnxModel(jsonInput);
outputType = outputTypeFromJson(jsonOutput, outputName);
+ writeMemoryStats(app, modelPath, MemoryStats.fromJson(jsonOutput));
if ( ! outputType.equals(TensorType.empty)) {
writeProbedOutputType(app, modelPath, contextKey, outputType);
}
@@ -53,6 +58,22 @@ public class OnnxModelProbe {
return outputType;
}
+ public static Optional<MemoryStats> probeMemoryStats(ApplicationPackage app, Path modelPath) {
+ return Optional.of(app.getFile(memoryStatsPath(modelPath)))
+ .filter(ApplicationFile::exists)
+ .map(file -> MemoryStats.fromJson(uncheck(() -> jsonParser.readTree(file.createReader()))));
+ }
+
+ private static void writeMemoryStats(ApplicationPackage app, Path modelPath, MemoryStats memoryStats) throws IOException {
+ String path = app.getFileReference(memoryStatsPath(modelPath)).getAbsolutePath();
+ IOUtils.writeFile(path, memoryStats.toJson().toPrettyString(), false);
+ }
+
+ private static Path memoryStatsPath(Path modelPath) {
+ var fileName = OnnxModelInfo.asValidIdentifier(modelPath.getRelative()) + ".memory_stats";
+ return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(fileName);
+ }
+
private static String createContextKey(String onnxName, Map<String, TensorType> inputTypes) {
StringBuilder key = new StringBuilder().append(onnxName).append(":");
inputTypes.entrySet().stream().sorted(Map.Entry.comparingByKey())
@@ -95,9 +116,7 @@ public class OnnxModelProbe {
return TensorType.empty;
}
- private static TensorType outputTypeFromJson(String json, String outputName) throws IOException {
- ObjectMapper m = new ObjectMapper();
- JsonNode root = m.readTree(json);
+ private static TensorType outputTypeFromJson(JsonNode root, String outputName) throws IOException {
if ( ! root.isObject() || ! root.has("outputs")) {
return TensorType.empty;
}
@@ -123,7 +142,7 @@ public class OnnxModelProbe {
return out.toString();
}
- private static String callVespaAnalyzeOnnxModel(String jsonInput) throws IOException, InterruptedException {
+ private static JsonNode callVespaAnalyzeOnnxModel(String jsonInput) throws IOException, InterruptedException {
StringBuilder output = new StringBuilder();
ProcessBuilder processBuilder = new ProcessBuilder(binary, "--probe-types");
@@ -148,7 +167,16 @@ public class OnnxModelProbe {
throw new IllegalArgumentException("Error from '" + binary + "'. Return code: " + returnCode + ". " +
"Output: '" + output + "'");
}
- return output.toString();
+ return jsonParser.readTree(output.toString());
+ }
+
+ public record MemoryStats(long vmSize, long vmRss) {
+ static MemoryStats fromJson(JsonNode json) {
+ return new MemoryStats(json.get("vm_size").asLong(), json.get("vm_rss").asLong());
+ }
+ JsonNode toJson() {
+ return jsonParser.createObjectNode().put("vm_size", vmSize).put("vm_rss", vmRss);
+ }
}
}
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java
index 245887a5d03..447614b8396 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java
@@ -4,6 +4,7 @@ package com.yahoo.vespa.model.application.validation;
import com.yahoo.config.ModelReference;
import com.yahoo.config.application.api.ApplicationFile;
+import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.config.application.api.DeployLogger;
import com.yahoo.config.model.NullConfigModelRegistry;
import com.yahoo.config.model.api.OnnxModelCost;
@@ -112,7 +113,7 @@ class JvmHeapSizeValidatorTest {
ModelCostDummy(long modelCost) { this.modelCost = modelCost; }
- @Override public Calculator newCalculator(DeployLogger logger) { return this; }
+ @Override public Calculator newCalculator(ApplicationPackage appPkg, DeployLogger logger) { return this; }
@Override public long aggregatedModelCostInBytes() { return totalCost.get(); }
@Override public void registerModel(ApplicationFile path) {}