summaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java
blob: 2ef81e3f1faf7acf349379f844b8470bae2a2768 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
package com.yahoo.vespa.model.ml;

import com.fasterxml.jackson.core.JsonEncoding;
import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.yahoo.config.application.api.ApplicationFile;
import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.tensor.TensorType;

import java.io.BufferedReader;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.util.Map;

/**
 * Defers to 'vespa-analyze-onnx-model' to determine the output type given
 * a set of inputs. For situations with symbolic dimension sizes that can't
 * easily be determined.
 *
 * Probed output types are remembered per model in a ".probed_output_types" file under
 * {@link ApplicationPackage#MODELS_GENERATED_REPLICATED_DIR}. Each line of that file holds a
 * context key (output name plus sorted input types) and the probed type, separated by a tab.
 * As a result, the external binary is only run for combinations that have not been seen before.
 *
 * @author lesters
 */
public class OnnxModelProbe {

    private static final String binary = "vespa-analyze-onnx-model";

    /** Shared parser for the probe output, instead of creating a new ObjectMapper per call. */
    private static final ObjectMapper jsonMapper = new ObjectMapper();

    /**
     * Returns the type of the given model output when the model is fed inputs of the given types.
     *
     * A previously probed result for the same output name and input types is reused when present.
     * Otherwise, if the model file exists in the application package, the external binary is run,
     * and a non-empty result is appended to the probe cache.
     *
     * Probing is best-effort. I/O problems are reported on stderr, and the type determined so far
     * (normally {@link TensorType#empty}) is returned.
     *
     * @param app the application package holding the model and the probe cache
     * @param modelPath path of the ONNX model file within the application package
     * @param outputName name of the model output whose type is wanted
     * @param inputTypes the types of the model inputs, keyed by input name
     * @return the probed output type, or {@link TensorType#empty} if it could not be determined
     * @throws IllegalArgumentException if the external binary exits with a non-zero return code
     */
    static TensorType probeModel(ApplicationPackage app, Path modelPath, String outputName, Map<String, TensorType> inputTypes) {
        TensorType outputType = TensorType.empty;
        String contextKey = createContextKey(outputName, inputTypes);

        try {
            // Check if output type has already been probed
            outputType = readProbedOutputType(app, modelPath, contextKey);

            // Otherwise, run vespa-analyze-onnx-model if the model is available
            if (outputType.equals(TensorType.empty) && app.getFile(modelPath).exists()) {
                String jsonInput = createJsonInput(app.getFileReference(modelPath).getAbsolutePath(), inputTypes);
                String jsonOutput = callVespaAnalyzeOnnxModel(jsonInput);
                outputType = outputTypeFromJson(jsonOutput, outputName);
                if ( ! outputType.equals(TensorType.empty)) {
                    writeProbedOutputType(app, modelPath, contextKey, outputType);
                }
            }

        } catch (IOException e) {
            e.printStackTrace(System.err);
        } catch (InterruptedException e) {
            // Restore the interrupt flag so code further up the stack can still observe the interruption
            Thread.currentThread().interrupt();
            e.printStackTrace(System.err);
        }

        return outputType;
    }

    /**
     * Builds the cache key identifying one probe.
     *
     * The key consists of the output name followed by each input "name:type", sorted by input
     * name so that map iteration order does not matter. For example, "out:a:type,b:type".
     * When there are no inputs, the key is just the output name.
     */
    private static String createContextKey(String onnxName, Map<String, TensorType> inputTypes) {
        StringBuilder key = new StringBuilder().append(onnxName).append(":");
        inputTypes.entrySet().stream().sorted(Map.Entry.comparingByKey())
                .forEachOrdered(e -> key.append(e.getKey()).append(":").append(e.getValue()).append(","));
        // Drop the trailing separator: the last "," or, with no inputs, the ":" after the name
        return key.substring(0, key.length()-1);
    }

    /** Returns the application-package path of the probe cache file for the given model. */
    private static Path probedOutputTypesPath(Path path) {
        String fileName = OnnxModelInfo.asValidIdentifier(path.getRelative()) + ".probed_output_types";
        return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(fileName);
    }

    /** Records a known output type for the given output name and input types in the probe cache. */
    static void writeProbedOutputType(ApplicationPackage app, Path modelPath, String output,
                                      Map<String, TensorType> inputTypes, TensorType type) throws IOException {
        writeProbedOutputType(app, modelPath, createContextKey(output, inputTypes), type);
    }

    /** Appends one "contextKey TAB type" line to the model's probe cache file. */
    private static void writeProbedOutputType(ApplicationPackage app, Path modelPath,
                                              String contextKey, TensorType type) throws IOException {
        String path = app.getFileReference(probedOutputTypesPath(modelPath)).getAbsolutePath();
        IOUtils.writeFile(path, contextKey + "\t" + type + "\n", true);
    }

    /**
     * Looks up a previously probed output type in the model's probe cache file.
     *
     * @return the cached type for the context key, or {@link TensorType#empty} if there is no
     *         cache file or no matching entry
     */
    private static TensorType readProbedOutputType(ApplicationPackage app, Path modelPath,
                                                   String contextKey) throws IOException {
        ApplicationFile file = app.getFile(probedOutputTypesPath(modelPath));
        if ( ! file.exists()) {
            return TensorType.empty;
        }
        try (BufferedReader reader = new BufferedReader(file.createReader())) {
            String line;
            while (null != (line = reader.readLine())) {
                String[] parts = line.split("\t");
                if (parts.length < 2) {
                    // Blank or truncated line: skip it rather than failing with an index error
                    continue;
                }
                if (parts[0].equals(contextKey)) {
                    return TensorType.fromSpec(parts[1]);
                }
            }
        }
        return TensorType.empty;
    }

    /**
     * Extracts the type of the named output from the binary's JSON response.
     *
     * @return the parsed type, or {@link TensorType#empty} if the response has no "outputs"
     *         object or that object lacks the named output
     */
    private static TensorType outputTypeFromJson(String json, String outputName) throws IOException {
        JsonNode root = jsonMapper.readTree(json);
        if ( ! root.isObject() || ! root.has("outputs")) {
            return TensorType.empty;
        }
        JsonNode outputs = root.get("outputs");
        if ( ! outputs.has(outputName)) {
            return TensorType.empty;
        }
        return TensorType.fromSpec(outputs.get(outputName).asText());
    }

    /**
     * Builds the JSON request sent to the binary.
     *
     * The request has the form {"model": path, "inputs": {name: type, ...}}.
     */
    private static String createJsonInput(String modelPath, Map<String, TensorType> inputTypes) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        // try-with-resources closes (and thereby flushes) the generator even if writing fails
        try (JsonGenerator g = new JsonFactory().createGenerator(out, JsonEncoding.UTF8)) {
            g.writeStartObject();
            g.writeStringField("model", modelPath);
            g.writeObjectFieldStart("inputs");
            for (Map.Entry<String, TensorType> input : inputTypes.entrySet()) {
                g.writeStringField(input.getKey(), input.getValue().toString());
            }
            g.writeEndObject();
            g.writeEndObject();
        }
        // The generator wrote UTF-8 bytes, so decode them as UTF-8, not with the platform default charset
        return new String(out.toByteArray(), StandardCharsets.UTF_8);
    }

    /**
     * Runs the binary with "--probe-types", feeds it the JSON request on stdin, and returns
     * everything it printed.
     *
     * @throws IllegalArgumentException if the binary exits with a non-zero return code
     */
    private static String callVespaAnalyzeOnnxModel(String jsonInput) throws IOException, InterruptedException {
        ProcessBuilder processBuilder = new ProcessBuilder(binary, "--probe-types");
        // stderr is merged into stdout so that diagnostics end up in the error message below.
        // NOTE(review): this also means anything printed to stderr on success becomes part of the
        // text later parsed as JSON — confirm the binary is silent on stderr when it succeeds.
        processBuilder.redirectErrorStream(true);
        Process process = processBuilder.start();
        try {
            // Write json request to process stdin; closing the stream signals end of input
            try (OutputStream os = process.getOutputStream()) {
                os.write(jsonInput.getBytes(StandardCharsets.UTF_8));
            }

            // Read output from stdout/stderr until the process closes it
            ByteArrayOutputStream collected = new ByteArrayOutputStream();
            try (InputStream inputStream = process.getInputStream()) {
                byte[] chunk = new byte[8192];
                int read;
                while ((read = inputStream.read(chunk)) != -1) {
                    collected.write(chunk, 0, read);
                }
            }
            // Decode all bytes at once so multi-byte UTF-8 sequences are not split into bogus chars
            String output = new String(collected.toByteArray(), StandardCharsets.UTF_8);

            int returnCode = process.waitFor();
            if (returnCode != 0) {
                throw new IllegalArgumentException("Error from '" + binary + "'. Return code: " + returnCode + ". Output:\n" + output);
            }
            return output;
        } finally {
            // Do not leave the child running if reading or waiting failed or was interrupted
            if (process.isAlive()) {
                process.destroy();
            }
        }
    }

}