diff options
-rw-r--r-- | config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java | 30 | ||||
-rw-r--r-- | config-model/src/main/java/com/yahoo/vespa/model/DefaultOnnxModelCost.java | 99 |
2 files changed, 129 insertions, 0 deletions
diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java new file mode 100644 index 00000000000..8c7f0db3bec --- /dev/null +++ b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java @@ -0,0 +1,30 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.config.model.api; + +import com.yahoo.config.ModelReference; +import com.yahoo.config.application.api.ApplicationFile; +import com.yahoo.config.application.api.DeployLogger; +import com.yahoo.config.provision.ApplicationId; + +/** + * @author bjorncs + */ +public interface OnnxModelCost { + + Calculator newCalculator(DeployLogger logger); + + interface Calculator { + long aggregatedModelCostInBytes(); + void registerModel(ApplicationFile path); + void registerModel(ModelReference ref); + } + + static OnnxModelCost testInstance() { + return (__) -> new Calculator() { + @Override public long aggregatedModelCostInBytes() { return 0; } + @Override public void registerModel(ApplicationFile path) {} + @Override public void registerModel(ModelReference ref) {} + }; + } +} diff --git a/config-model/src/main/java/com/yahoo/vespa/model/DefaultOnnxModelCost.java b/config-model/src/main/java/com/yahoo/vespa/model/DefaultOnnxModelCost.java new file mode 100644 index 00000000000..76733872882 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/vespa/model/DefaultOnnxModelCost.java @@ -0,0 +1,99 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.vespa.model; + +import com.yahoo.config.ModelReference; +import com.yahoo.config.application.api.ApplicationFile; +import com.yahoo.config.application.api.DeployLogger; +import com.yahoo.config.model.api.OnnxModelCost; + +import java.io.IOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.time.Duration; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.logging.Level; + +import static com.yahoo.yolean.Exceptions.uncheck; + +/** + * Aggregates estimated footprint of configured ONNX models. + * + * @author bjorncs + */ +public class DefaultOnnxModelCost implements OnnxModelCost { + + @Override + public Calculator newCalculator(DeployLogger logger) { + return new CalculatorImpl(logger); + } + + private static class CalculatorImpl implements Calculator { + private final DeployLogger log; + + private final ConcurrentMap<String, Long> modelCost = new ConcurrentHashMap<>(); + + private CalculatorImpl(DeployLogger log) { + this.log = log; + } + + @Override + public long aggregatedModelCostInBytes() { + return modelCost.values().stream().mapToLong(Long::longValue).sum(); + } + + @Override + public void registerModel(ApplicationFile f) { + String path = f.getPath().getRelative(); + if (alreadyAnalyzed(path)) return; + log.log(Level.FINE, () -> "Register model '%s'".formatted(path)); + deductJvmHeapSizeWithModelCost(f.exists() ? f.getSize() : 0, path); + } + + @Override + public void registerModel(ModelReference ref) { + log.log(Level.FINE, () -> "Register model '%s'".formatted(ref.toString())); + if (ref.path().isPresent()) { + var path = Paths.get(ref.path().get().value()); + var source = path.getFileName().toString(); + if (alreadyAnalyzed(source)) return; + deductJvmHeapSizeWithModelCost(uncheck(() -> Files.exists(path) ? Files.size(path) : 0), source); + } else if (ref.url().isPresent()) deductJvmHeapSizeWithModelCost(URI.create(ref.url().get().value())); + else throw new IllegalStateException(ref.toString()); + } + + private void deductJvmHeapSizeWithModelCost(URI uri) { + if (alreadyAnalyzed(uri.toString())) return; + if (uri.getScheme().equals("http") || uri.getScheme().equals("https")) { + try { + var timeout = Duration.ofSeconds(3); + var httpClient = HttpClient.newBuilder().connectTimeout(timeout).build(); + var request = HttpRequest.newBuilder(uri).timeout(timeout).method("HEAD", HttpRequest.BodyPublishers.noBody()).build(); + var response = httpClient.send(request, HttpResponse.BodyHandlers.discarding()); + var contentLength = response.headers().firstValue("Content-Length").orElse("0"); + log.log(Level.FINE, () -> "Got content length '%s' for '%s'".formatted(contentLength, uri)); + deductJvmHeapSizeWithModelCost(Long.parseLong(contentLength), uri.toString()); + } catch (IllegalArgumentException | InterruptedException | IOException e) { + log.log(Level.INFO, () -> "Failed to get model size for '%s': %s".formatted(uri, e.getMessage()), e); + } + } + } + + private void deductJvmHeapSizeWithModelCost(long size, String source) { + long fallbackModelSize = 1024*1024*1024; + long estimatedCost = Math.max(300*1024*1024, (long) (1.4D * (size > 0 ? size : fallbackModelSize) + 100*1024*1024)); + log.log(Level.FINE, () -> + "Estimated %s footprint for model of size %s ('%s')".formatted(mb(estimatedCost), mb(size), source)); + modelCost.put(source, estimatedCost); + } + + private boolean alreadyAnalyzed(String source) { return modelCost.containsKey(source); } + + private static String mb(long bytes) { return "%dMB".formatted(bytes / (1024*1024)); } + } +} |