aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/vespa/model/DefaultOnnxModelCost.java
blob: fddf84093767b9a130b03d34574d3546335a6a54 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

package com.yahoo.vespa.model;

import com.yahoo.config.ModelReference;
import com.yahoo.config.application.api.ApplicationFile;
import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.config.application.api.DeployLogger;
import com.yahoo.config.model.api.OnnxModelCost;
import com.yahoo.vespa.model.ml.OnnxModelProbe;

import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.time.Duration;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.logging.Level;

import static com.yahoo.yolean.Exceptions.uncheck;

/**
 * Aggregates estimated footprint of configured ONNX models.
 *
 * @author bjorncs
 */
public class DefaultOnnxModelCost implements OnnxModelCost {

    @Override
    public Calculator newCalculator(ApplicationPackage appPkg, DeployLogger logger) {
        return new CalculatorImpl(appPkg, logger);
    }

    private static class CalculatorImpl implements Calculator {
        private final DeployLogger log;
        private final ApplicationPackage appPkg;

        private final ConcurrentMap<String, Long> modelCost = new ConcurrentHashMap<>();

        private CalculatorImpl(ApplicationPackage appPkg, DeployLogger log) {
            this.appPkg = appPkg;
            this.log = log;
        }

        @Override
        public long aggregatedModelCostInBytes() {
            return modelCost.values().stream().mapToLong(Long::longValue).sum();
        }

        @Override
        public void registerModel(ApplicationFile f) {
            String path = f.getPath().getRelative();
            if (alreadyAnalyzed(path)) return;
            log.log(Level.FINE, () -> "Register model '%s'".formatted(path));
            if (f.exists() && appPkg != null) {
                var memoryStats = OnnxModelProbe.probeMemoryStats(appPkg, f.getPath()).orElse(null);
                if (memoryStats != null) {
                    log.log(Level.FINE, () -> "Register model '%s' with memory stats: %s".formatted(path, memoryStats));
                    deductJvmHeapSizeWithModelCost(f.getSize(), memoryStats, path);
                } else {
                    deductJvmHeapSizeWithModelCost(f.getSize(), path);
                }
            } else {
                deductJvmHeapSizeWithModelCost(0, path);
            }
        }

        @Override
        public void registerModel(ModelReference ref) {
            log.log(Level.FINE, () -> "Register model '%s'".formatted(ref.toString()));
            if (ref.path().isPresent()) {
                var path = Paths.get(ref.path().get().value());
                var source = path.getFileName().toString();
                if (alreadyAnalyzed(source)) return;
                deductJvmHeapSizeWithModelCost(uncheck(() -> Files.exists(path) ? Files.size(path) : 0), source);
            } else if (ref.url().isPresent()) deductJvmHeapSizeWithModelCost(URI.create(ref.url().get().value()));
            else throw new IllegalStateException(ref.toString());
        }

        private void deductJvmHeapSizeWithModelCost(URI uri) {
            if (alreadyAnalyzed(uri.toString())) return;
            if (uri.getScheme().equals("http") || uri.getScheme().equals("https")) {
                try {
                    var timeout = Duration.ofSeconds(3);
                    var httpClient = HttpClient.newBuilder().connectTimeout(timeout).build();
                    var request = HttpRequest.newBuilder(uri).timeout(timeout).method("HEAD", HttpRequest.BodyPublishers.noBody()).build();
                    var response = httpClient.send(request, HttpResponse.BodyHandlers.discarding());
                    var contentLength = response.headers().firstValue("Content-Length").orElse("0");
                    log.log(Level.FINE, () -> "Got content length '%s' for '%s'".formatted(contentLength, uri));
                    deductJvmHeapSizeWithModelCost(Long.parseLong(contentLength), uri.toString());
                } catch (IllegalArgumentException | InterruptedException | IOException e) {
                    log.log(Level.INFO, () -> "Failed to get model size for '%s': %s".formatted(uri, e.getMessage()), e);
                }
            }
        }

        private void deductJvmHeapSizeWithModelCost(long size, String source) {
            long fallbackModelSize = 1024*1024*1024;
            long estimatedCost = Math.max(300*1024*1024, (long) (1.4D * (size > 0 ? size : fallbackModelSize) + 100*1024*1024));
            log.log(Level.FINE, () ->
                    "Estimated %s footprint for model of size %s ('%s')".formatted(mb(estimatedCost), mb(size), source));
            modelCost.put(source, estimatedCost);
        }

        private void deductJvmHeapSizeWithModelCost(long size, OnnxModelProbe.MemoryStats stats, String source) {
            long estimatedCost = (long)(1.1D * stats.vmSize());
            log.log(Level.FINE, () ->
                    "Estimated %s footprint for model of size %s ('%s')".formatted(mb(estimatedCost), mb(size), source));
            modelCost.put(source, estimatedCost);
        }

        private boolean alreadyAnalyzed(String source) { return modelCost.containsKey(source); }

        private static String mb(long bytes) { return "%dMB".formatted(bytes / (1024*1024)); }
    }
}