path: root/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java
package com.yahoo.vespa.model.ml;

import com.fasterxml.jackson.core.JsonEncoding;
import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.yahoo.config.application.api.ApplicationFile;
import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.tensor.TensorType;

import java.io.BufferedReader;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.util.Map;

/**
 * Defers to 'vespa-analyze-onnx-model' to determine the output type of an ONNX
 * model given a set of input types. Used in situations, such as models with
 * symbolic dimension sizes, where the output type cannot easily be determined
 * statically.
 *
 * @author lesters
 */
public class OnnxModelProbe {

    private static final String binary = "vespa-analyze-onnx-model";
    private static final ObjectMapper jsonParser = new ObjectMapper();

    static TensorType probeModel(ApplicationPackage app, Path modelPath, String outputName, Map<String, TensorType> inputTypes) {
        TensorType outputType = TensorType.empty;
        String contextKey = createContextKey(outputName, inputTypes);

        try {
            // Check if output type has already been probed
            outputType = readProbedOutputType(app, modelPath, contextKey);

            // Otherwise, run vespa-analyze-onnx-model if the model is available
            if (outputType.equals(TensorType.empty) && app.getFile(modelPath).exists()) {
                String jsonInput = createJsonInput(app.getFileReference(modelPath).getAbsolutePath(), inputTypes);
                var jsonOutput = callVespaAnalyzeOnnxModel(jsonInput);
                outputType = outputTypeFromJson(jsonOutput, outputName);
                writeMemoryStats(app, modelPath, MemoryStats.fromJson(jsonOutput));
                if ( ! outputType.equals(TensorType.empty)) {
                    writeProbedOutputType(app, modelPath, contextKey, outputType);
                }
            }

        } catch (IllegalArgumentException | IOException | InterruptedException ignored) { }

        return outputType;
    }
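
    // Illustrative usage (hypothetical model path, output and input names; 'app' is the
    // application package being built):
    //   TensorType type = OnnxModelProbe.probeModel(app, Path.fromString("files/ranker.onnx"), "output_0",
    //           Map.of("input_ids", TensorType.fromSpec("tensor(d0[1],d1[128])")));
    // Returns TensorType.empty if the type could not be determined.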

    private static void writeMemoryStats(ApplicationPackage app, Path modelPath, MemoryStats memoryStats) throws IOException {
        String path = app.getFileReference(memoryStatsPath(modelPath)).getAbsolutePath();
        IOUtils.writeFile(path, memoryStats.toJson().toPrettyString(), false);
    }

    private static Path memoryStatsPath(Path modelPath) {
        var fileName = OnnxModelInfo.asValidIdentifier(modelPath.getRelative()) + ".memory_stats";
        return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(fileName);
    }

    private static String createContextKey(String onnxName, Map<String, TensorType> inputTypes) {
        StringBuilder key = new StringBuilder().append(onnxName).append(":");
        inputTypes.entrySet().stream().sorted(Map.Entry.comparingByKey())
                .forEachOrdered(e -> key.append(e.getKey()).append(":").append(e.getValue()).append(","));
        return key.substring(0, key.length()-1);
    }
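
    // Example (hypothetical output and input names): for output "score" with inputs
    // "a" of type tensor(d0[1]) and "b" of type tensor(d0[2]), the context key becomes
    //   score:a:tensor(d0[1]),b:tensor(d0[2])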

    private static Path probedOutputTypesPath(Path path) {
        String fileName = OnnxModelInfo.asValidIdentifier(path.getRelative()) + ".probed_output_types";
        return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(fileName);
    }

    static void writeProbedOutputType(ApplicationPackage app, Path modelPath, String output,
                                      Map<String, TensorType> inputTypes, TensorType type) throws IOException {
        writeProbedOutputType(app, modelPath, createContextKey(output, inputTypes), type);
    }

    private static void writeProbedOutputType(ApplicationPackage app, Path modelPath,
                                              String contextKey, TensorType type) throws IOException {
        String path = app.getFileReference(probedOutputTypesPath(modelPath)).getAbsolutePath();
        IOUtils.writeFile(path, contextKey + "\t" + type + "\n", true);
    }
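
    // The probed-types file is an append-only, tab-separated lookup table; each line has the
    // form "<contextKey>\t<output type spec>", which readProbedOutputType() below scans for a
    // matching context key.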

    private static TensorType readProbedOutputType(ApplicationPackage app, Path modelPath,
                                                   String contextKey) throws IOException {
        ApplicationFile file = app.getFile(probedOutputTypesPath(modelPath));
        if ( ! file.exists()) {
            return TensorType.empty;
        }
        try (BufferedReader reader = new BufferedReader(file.createReader())) {
            String line;
            while (null != (line = reader.readLine())) {
                String[] parts = line.split("\t");
                String key = parts[0];
                if (key.equals(contextKey)) {
                    return TensorType.fromSpec(parts[1]);
                }
            }
        }
        return TensorType.empty;
    }

    private static TensorType outputTypeFromJson(JsonNode root, String outputName) throws IOException {
        if ( ! root.isObject() || ! root.has("outputs")) {
            return TensorType.empty;
        }
        JsonNode outputs = root.get("outputs");
        if ( ! outputs.has(outputName)) {
            return TensorType.empty;
        }
        return TensorType.fromSpec(outputs.get(outputName).asText());
    }

    private static String createJsonInput(String modelPath, Map<String, TensorType> inputTypes) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        JsonGenerator g = new JsonFactory().createGenerator(out, JsonEncoding.UTF8);
        g.writeStartObject();
        g.writeStringField("model", modelPath);
        g.writeObjectFieldStart("inputs");
        for (Map.Entry<String, TensorType> input : inputTypes.entrySet()) {
            g.writeStringField(input.getKey(), input.getValue().toString());
        }
        g.writeEndObject();
        g.writeEndObject();
        g.close();
        return out.toString(StandardCharsets.UTF_8);
    }
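
    // Example request (hypothetical path and input name); this is what is written to
    // 'vespa-analyze-onnx-model --probe-types' on stdin:
    //   {"model":"/path/to/ranker.onnx","inputs":{"input_ids":"tensor(d0[1],d1[128])"}}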

    private static JsonNode callVespaAnalyzeOnnxModel(String jsonInput) throws IOException, InterruptedException {
        StringBuilder output = new StringBuilder();

        ProcessBuilder processBuilder = new ProcessBuilder(binary, "--probe-types");
        processBuilder.redirectError(ProcessBuilder.Redirect.DISCARD);
        Process process = processBuilder.start();

        // Write json request to process stdin
        OutputStream os = process.getOutputStream();
        os.write(jsonInput.getBytes(StandardCharsets.UTF_8));
        os.close();

        // Read json output from stdout as UTF-8
        InputStream inputStream = process.getInputStream();
        ByteArrayOutputStream outputBytes = new ByteArrayOutputStream();
        inputStream.transferTo(outputBytes);
        output.append(outputBytes.toString(StandardCharsets.UTF_8));

        int returnCode = process.waitFor();
        if (returnCode != 0) {
            throw new IllegalArgumentException("Error from '" + binary + "'. Return code: " + returnCode + ". " +
                                               "Output: '" + output + "'");
        }
        return jsonParser.readTree(output.toString());
    }
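
    // The tool's stdout is expected (as assumed by outputTypeFromJson and MemoryStats.fromJson)
    // to be a json object along the lines of (hypothetical values):
    //   {"outputs":{"output_0":"tensor<float>(d0[1],d1[2])"},"vm_size":123456,"vm_rss":65432}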

    public record MemoryStats(long vmSize, long vmRss) {
        static MemoryStats fromJson(JsonNode json) {
            return new MemoryStats(json.get("vm_size").asLong(), json.get("vm_rss").asLong());
        }
        JsonNode toJson() {
            return jsonParser.createObjectNode().put("vm_size", vmSize).put("vm_rss", vmRss);
        }
    }

}