aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java30
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/DefaultOnnxModelCost.java99
2 files changed, 129 insertions, 0 deletions
diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java
new file mode 100644
index 00000000000..8c7f0db3bec
--- /dev/null
+++ b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java
@@ -0,0 +1,30 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.config.model.api;
+
+import com.yahoo.config.ModelReference;
+import com.yahoo.config.application.api.ApplicationFile;
+import com.yahoo.config.application.api.DeployLogger;
+import com.yahoo.config.provision.ApplicationId;
+
+/**
+ * @author bjorncs
+ */
+public interface OnnxModelCost {
+
+ Calculator newCalculator(DeployLogger logger);
+
+ interface Calculator {
+ long aggregatedModelCostInBytes();
+ void registerModel(ApplicationFile path);
+ void registerModel(ModelReference ref);
+ }
+
+ static OnnxModelCost testInstance() {
+ return (__) -> new Calculator() {
+ @Override public long aggregatedModelCostInBytes() { return 0; }
+ @Override public void registerModel(ApplicationFile path) {}
+ @Override public void registerModel(ModelReference ref) {}
+ };
+ }
+}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/DefaultOnnxModelCost.java b/config-model/src/main/java/com/yahoo/vespa/model/DefaultOnnxModelCost.java
new file mode 100644
index 00000000000..76733872882
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/vespa/model/DefaultOnnxModelCost.java
@@ -0,0 +1,99 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.vespa.model;
+
+import com.yahoo.config.ModelReference;
+import com.yahoo.config.application.api.ApplicationFile;
+import com.yahoo.config.application.api.DeployLogger;
+import com.yahoo.config.model.api.OnnxModelCost;
+
+import java.io.IOException;
+import java.net.URI;
+import java.net.http.HttpClient;
+import java.net.http.HttpRequest;
+import java.net.http.HttpResponse;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.time.Duration;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.logging.Level;
+
+import static com.yahoo.yolean.Exceptions.uncheck;
+
+/**
+ * Aggregates estimated footprint of configured ONNX models.
+ *
+ * @author bjorncs
+ */
+public class DefaultOnnxModelCost implements OnnxModelCost {
+
+ @Override
+ public Calculator newCalculator(DeployLogger logger) {
+ return new CalculatorImpl(logger);
+ }
+
+ private static class CalculatorImpl implements Calculator {
+ private final DeployLogger log;
+
+ private final ConcurrentMap<String, Long> modelCost = new ConcurrentHashMap<>();
+
+ private CalculatorImpl(DeployLogger log) {
+ this.log = log;
+ }
+
+ @Override
+ public long aggregatedModelCostInBytes() {
+ return modelCost.values().stream().mapToLong(Long::longValue).sum();
+ }
+
+ @Override
+ public void registerModel(ApplicationFile f) {
+ String path = f.getPath().getRelative();
+ if (alreadyAnalyzed(path)) return;
+ log.log(Level.FINE, () -> "Register model '%s'".formatted(path));
+ deductJvmHeapSizeWithModelCost(f.exists() ? f.getSize() : 0, path);
+ }
+
+ @Override
+ public void registerModel(ModelReference ref) {
+ log.log(Level.FINE, () -> "Register model '%s'".formatted(ref.toString()));
+ if (ref.path().isPresent()) {
+ var path = Paths.get(ref.path().get().value());
+ var source = path.getFileName().toString();
+ if (alreadyAnalyzed(source)) return;
+ deductJvmHeapSizeWithModelCost(uncheck(() -> Files.exists(path) ? Files.size(path) : 0), source);
+ } else if (ref.url().isPresent()) deductJvmHeapSizeWithModelCost(URI.create(ref.url().get().value()));
+ else throw new IllegalStateException(ref.toString());
+ }
+
+ private void deductJvmHeapSizeWithModelCost(URI uri) {
+ if (alreadyAnalyzed(uri.toString())) return;
+ if (uri.getScheme().equals("http") || uri.getScheme().equals("https")) {
+ try {
+ var timeout = Duration.ofSeconds(3);
+ var httpClient = HttpClient.newBuilder().connectTimeout(timeout).build();
+ var request = HttpRequest.newBuilder(uri).timeout(timeout).method("HEAD", HttpRequest.BodyPublishers.noBody()).build();
+ var response = httpClient.send(request, HttpResponse.BodyHandlers.discarding());
+ var contentLength = response.headers().firstValue("Content-Length").orElse("0");
+ log.log(Level.FINE, () -> "Got content length '%s' for '%s'".formatted(contentLength, uri));
+ deductJvmHeapSizeWithModelCost(Long.parseLong(contentLength), uri.toString());
+ } catch (IllegalArgumentException | InterruptedException | IOException e) {
+ log.log(Level.INFO, () -> "Failed to get model size for '%s': %s".formatted(uri, e.getMessage()), e);
+ }
+ }
+ }
+
+ private void deductJvmHeapSizeWithModelCost(long size, String source) {
+ long fallbackModelSize = 1024*1024*1024;
+ long estimatedCost = Math.max(300*1024*1024, (long) (1.4D * (size > 0 ? size : fallbackModelSize) + 100*1024*1024));
+ log.log(Level.FINE, () ->
+ "Estimated %s footprint for model of size %s ('%s')".formatted(mb(estimatedCost), mb(size), source));
+ modelCost.put(source, estimatedCost);
+ }
+
+ private boolean alreadyAnalyzed(String source) { return modelCost.containsKey(source); }
+
+ private static String mb(long bytes) { return "%dMB".formatted(bytes / (1024*1024)); }
+ }
+}