diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java | 153 |
1 files changed, 0 insertions, 153 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java deleted file mode 100644 index 2ef81e3f1fa..00000000000 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java +++ /dev/null @@ -1,153 +0,0 @@ -package com.yahoo.vespa.model.ml; - -import com.fasterxml.jackson.core.JsonEncoding; -import com.fasterxml.jackson.core.JsonFactory; -import com.fasterxml.jackson.core.JsonGenerator; -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.yahoo.config.application.api.ApplicationFile; -import com.yahoo.config.application.api.ApplicationPackage; -import com.yahoo.io.IOUtils; -import com.yahoo.path.Path; -import com.yahoo.tensor.TensorType; - -import java.io.BufferedReader; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; -import java.nio.charset.StandardCharsets; -import java.util.Map; - -/** - * Defers to 'vespa-analyze-onnx-model' to determine the output type given - * a set of inputs. For situations with symbolic dimension sizes that can't - * easily be determined. - * - * @author lesters - */ -public class OnnxModelProbe { - - private static final String binary = "vespa-analyze-onnx-model"; - - static TensorType probeModel(ApplicationPackage app, Path modelPath, String outputName, Map<String, TensorType> inputTypes) { - TensorType outputType = TensorType.empty; - String contextKey = createContextKey(outputName, inputTypes); - - try { - // Check if output type has already been probed - outputType = readProbedOutputType(app, modelPath, contextKey); - - // Otherwise, run vespa-analyze-onnx-model if the model is available - if (outputType.equals(TensorType.empty) && app.getFile(modelPath).exists()) { - String jsonInput = createJsonInput(app.getFileReference(modelPath).getAbsolutePath(), inputTypes); - String jsonOutput = callVespaAnalyzeOnnxModel(jsonInput); - outputType = outputTypeFromJson(jsonOutput, outputName); - if ( ! outputType.equals(TensorType.empty)) { - writeProbedOutputType(app, modelPath, contextKey, outputType); - } - } - - } catch (IOException | InterruptedException e) { - e.printStackTrace(System.err); - } - - return outputType; - } - - private static String createContextKey(String onnxName, Map<String, TensorType> inputTypes) { - StringBuilder key = new StringBuilder().append(onnxName).append(":"); - inputTypes.entrySet().stream().sorted(Map.Entry.comparingByKey()) - .forEachOrdered(e -> key.append(e.getKey()).append(":").append(e.getValue()).append(",")); - return key.substring(0, key.length()-1); - } - - private static Path probedOutputTypesPath(Path path) { - String fileName = OnnxModelInfo.asValidIdentifier(path.getRelative()) + ".probed_output_types"; - return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(fileName); - } - - static void writeProbedOutputType(ApplicationPackage app, Path modelPath, String output, - Map<String, TensorType> inputTypes, TensorType type) throws IOException { - writeProbedOutputType(app, modelPath, createContextKey(output, inputTypes), type); - } - - private static void writeProbedOutputType(ApplicationPackage app, Path modelPath, - String contextKey, TensorType type) throws IOException { - String path = app.getFileReference(probedOutputTypesPath(modelPath)).getAbsolutePath(); - IOUtils.writeFile(path, contextKey + "\t" + type + "\n", true); - } - - private static TensorType readProbedOutputType(ApplicationPackage app, Path modelPath, - String contextKey) throws IOException { - ApplicationFile file = app.getFile(probedOutputTypesPath(modelPath)); - if ( ! file.exists()) { - return TensorType.empty; - } - try (BufferedReader reader = new BufferedReader(file.createReader())) { - String line; - while (null != (line = reader.readLine())) { - String[] parts = line.split("\t"); - String key = parts[0]; - if (key.equals(contextKey)) { - return TensorType.fromSpec(parts[1]); - } - } - } - return TensorType.empty; - } - - private static TensorType outputTypeFromJson(String json, String outputName) throws IOException { - ObjectMapper m = new ObjectMapper(); - JsonNode root = m.readTree(json); - if ( ! root.isObject() || ! root.has("outputs")) { - return TensorType.empty; - } - JsonNode outputs = root.get("outputs"); - if ( ! outputs.has(outputName)) { - return TensorType.empty; - } - return TensorType.fromSpec(outputs.get(outputName).asText()); - } - - private static String createJsonInput(String modelPath, Map<String, TensorType> inputTypes) throws IOException { - ByteArrayOutputStream out = new ByteArrayOutputStream(); - JsonGenerator g = new JsonFactory().createGenerator(out, JsonEncoding.UTF8); - g.writeStartObject(); - g.writeStringField("model", modelPath); - g.writeObjectFieldStart("inputs"); - for (Map.Entry<String, TensorType> input : inputTypes.entrySet()) { - g.writeStringField(input.getKey(), input.getValue().toString()); - } - g.writeEndObject(); - g.writeEndObject(); - g.close(); - return out.toString(); - } - - private static String callVespaAnalyzeOnnxModel(String jsonInput) throws IOException, InterruptedException { - ProcessBuilder processBuilder = new ProcessBuilder(binary, "--probe-types"); - processBuilder.redirectErrorStream(true); - StringBuilder output = new StringBuilder(); - Process process = processBuilder.start(); - - // Write json array to process stdin - OutputStream os = process.getOutputStream(); - os.write(jsonInput.getBytes(StandardCharsets.UTF_8)); - os.close(); - - // Read output from stdout/stderr - InputStream inputStream = process.getInputStream(); - while (true) { - int b = inputStream.read(); - if (b == -1) break; - output.append((char)b); - } - int returnCode = process.waitFor(); - if (returnCode != 0) { - throw new IllegalArgumentException("Error from '" + binary + "'. Return code: " + returnCode + ". Output:\n" + output); - } - return output.toString(); - } - -} |