diff options
author | Lester Solbakken <lesters@oath.com> | 2022-02-23 14:45:31 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2022-02-23 14:45:31 +0100 |
commit | f59b5aa627018e6574249488c802fbccd8360688 (patch) | |
tree | c027015867cc52dbd2d2dccffc0ea08f7a274e14 /config-model/src/main/java | |
parent | d4d0db00c51be4b195721b331c3b958bf07b882d (diff) |
Add probing of ONNX models for type resolving in Java
Diffstat (limited to 'config-model/src/main/java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java | 38 | ||||
-rw-r--r-- | config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java | 151 |
2 files changed, 183 insertions, 6 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java index 88139de7888..2742dc59fcd 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java @@ -36,13 +36,15 @@ import java.util.stream.Collectors; */ public class OnnxModelInfo { + private final ApplicationPackage app; private final String modelPath; private final String defaultOutput; private final Map<String, OnnxTypeInfo> inputs; private final Map<String, OnnxTypeInfo> outputs; private final Map<String, TensorType> vespaTypes = new HashMap<>(); - private OnnxModelInfo(String path, Map<String, OnnxTypeInfo> inputs, Map<String, OnnxTypeInfo> outputs, String defaultOutput) { + private OnnxModelInfo(ApplicationPackage app, String path, Map<String, OnnxTypeInfo> inputs, Map<String, OnnxTypeInfo> outputs, String defaultOutput) { + this.app = app; this.modelPath = path; this.inputs = Collections.unmodifiableMap(inputs); this.outputs = Collections.unmodifiableMap(outputs); @@ -79,7 +81,15 @@ public class OnnxModelInfo { Set<Long> unboundSizes = new HashSet<>(); Map<String, Long> symbolicSizes = new HashMap<>(); resolveUnknownDimensionSizes(inputTypes, symbolicSizes, unboundSizes); - return onnxTypeInfo.toVespaTensorType(symbolicSizes, unboundSizes); + + TensorType type = TensorType.empty; + if (inputTypes.size() > 0 && onnxTypeInfo.needModelProbe(symbolicSizes)) { + type = OnnxModelProbe.probeModel(app, Path.fromString(modelPath), onnxName, inputTypes); + } + if (type.equals(TensorType.empty)) { + type = onnxTypeInfo.toVespaTensorType(symbolicSizes, unboundSizes); + } + return type; } return vespaTypes.computeIfAbsent(onnxName, v -> onnxTypeInfo.toVespaTensorType()); } @@ -150,7 +160,8 @@ public class OnnxModelInfo { Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream); String json = onnxModelToJson(model, path); storeGeneratedInfo(json, path, app); - return jsonToModelInfo(json); + return jsonToModelInfo(json, app); + } catch (IOException e) { throw new IllegalArgumentException("Unable to parse ONNX model", e); } @@ -159,7 +170,7 @@ public class OnnxModelInfo { static private OnnxModelInfo loadFromGeneratedInfo(Path path, ApplicationPackage app) { try { String json = readGeneratedInfo(path, app); - return jsonToModelInfo(json); + return jsonToModelInfo(json, app); } catch (IOException e) { throw new IllegalArgumentException("Unable to parse ONNX model", e); } @@ -202,7 +213,7 @@ public class OnnxModelInfo { return out.toString(); } - static public OnnxModelInfo jsonToModelInfo(String json) throws IOException { + static public OnnxModelInfo jsonToModelInfo(String json, ApplicationPackage app) throws IOException { ObjectMapper m = new ObjectMapper(); JsonNode root = m.readTree(json); Map<String, OnnxTypeInfo> inputs = new HashMap<>(); @@ -222,7 +233,7 @@ public class OnnxModelInfo { if (root.get("outputs").has(0)) { defaultOutput = root.get("outputs").get(0).get("name").textValue(); } - return new OnnxModelInfo(path, inputs, outputs, defaultOutput); + return new OnnxModelInfo(app, path, inputs, outputs, defaultOutput); } static private void onnxTypeToJson(JsonGenerator g, Onnx.ValueInfoProto valueInfo) throws IOException { @@ -353,6 +364,21 @@ public class OnnxModelInfo { return builder.build(); } + boolean needModelProbe(Map<String, Long> symbolicSizes) { + for (OnnxDimensionInfo onnxDimension : dimensions) { + if (onnxDimension.hasSymbolicName()) { + if (symbolicSizes == null) + return true; + if ( ! symbolicSizes.containsKey(onnxDimension.getSymbolicName())) { + return true; + } + } else if (onnxDimension.getSize() == 0) { + return true; + } + } + return false; + } + @Override public String toString() { return "(" + valueType.id() + ")" + diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java new file mode 100644 index 00000000000..3a77a349af3 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java @@ -0,0 +1,151 @@ +package com.yahoo.vespa.model.ml; + +import com.fasterxml.jackson.core.JsonEncoding; +import com.fasterxml.jackson.core.JsonFactory; +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.yahoo.config.application.api.ApplicationFile; +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.io.IOUtils; +import com.yahoo.path.Path; +import com.yahoo.tensor.TensorType; + +import java.io.BufferedReader; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.charset.StandardCharsets; +import java.util.Map; + +/** + * Defers to 'vespa-analyze-onnx-model' to determine the output type given + * a set of inputs. For situations with symbolic dimension sizes that can't + * easily be determined. + * + * @author lesters + */ +public class OnnxModelProbe { + + private static final String binary = "vespa-analyze-onnx-model"; + + static TensorType probeModel(ApplicationPackage app, Path modelPath, String outputName, Map<String, TensorType> inputTypes) { + TensorType outputType = TensorType.empty; + String contextKey = createContextKey(outputName, inputTypes); + + try { + // Check if output type has already been probed + outputType = readProbedOutputType(app, modelPath, contextKey); + + // Otherwise, run vespa-analyze-onnx-model if the model is available + if (outputType.equals(TensorType.empty) && app.getFile(modelPath).exists()) { + String jsonInput = createJsonInput(app.getFileReference(modelPath).getAbsolutePath(), inputTypes); + String jsonOutput = callVespaAnalyzeOnnxModel(jsonInput); + outputType = outputTypeFromJson(jsonOutput, outputName); + writeProbedOutputType(app, modelPath, contextKey, outputType); + } + + } catch (IOException | InterruptedException e) { + throw new IllegalArgumentException("Unable to probe ONNX model", e); + } + + return outputType; + } + + private static String createContextKey(String onnxName, Map<String, TensorType> inputTypes) { + StringBuilder key = new StringBuilder().append(onnxName).append(":"); + inputTypes.entrySet().stream().sorted(Map.Entry.comparingByKey()) + .forEachOrdered(e -> key.append(e.getKey()).append(":").append(e.getValue()).append(",")); + return key.substring(0, key.length()-1); + } + + private static Path probedOutputTypesPath(Path path) { + String fileName = OnnxModelInfo.asValidIdentifier(path.getRelative()) + ".probed_output_types"; + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(fileName); + } + + static void writeProbedOutputType(ApplicationPackage app, Path modelPath, String output, + Map<String, TensorType> inputTypes, TensorType type) throws IOException { + writeProbedOutputType(app, modelPath, createContextKey(output, inputTypes), type); + } + + private static void writeProbedOutputType(ApplicationPackage app, Path modelPath, + String contextKey, TensorType type) throws IOException { + String path = app.getFileReference(probedOutputTypesPath(modelPath)).getAbsolutePath(); + IOUtils.writeFile(path, contextKey + "\t" + type + "\n", true); + } + + private static TensorType readProbedOutputType(ApplicationPackage app, Path modelPath, + String contextKey) throws IOException { + ApplicationFile file = app.getFile(probedOutputTypesPath(modelPath)); + if ( ! file.exists()) { + return TensorType.empty; + } + try (BufferedReader reader = new BufferedReader(file.createReader())) { + String line; + while (null != (line = reader.readLine())) { + String[] parts = line.split("\t"); + String key = parts[0]; + if (key.equals(contextKey)) { + return TensorType.fromSpec(parts[1]); + } + } + } + return TensorType.empty; + } + + private static TensorType outputTypeFromJson(String json, String outputName) throws IOException { + ObjectMapper m = new ObjectMapper(); + JsonNode root = m.readTree(json); + if ( ! root.isObject() || ! root.has("outputs")) { + return TensorType.empty; + } + JsonNode outputs = root.get("outputs"); + if ( ! outputs.has(outputName)) { + return TensorType.empty; + } + return TensorType.fromSpec(outputs.get(outputName).asText()); + } + + private static String createJsonInput(String modelPath, Map<String, TensorType> inputTypes) throws IOException { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + JsonGenerator g = new JsonFactory().createGenerator(out, JsonEncoding.UTF8); + g.writeStartObject(); + g.writeStringField("model", modelPath); + g.writeObjectFieldStart("inputs"); + for (Map.Entry<String, TensorType> input : inputTypes.entrySet()) { + g.writeStringField(input.getKey(), input.getValue().toString()); + } + g.writeEndObject(); + g.writeEndObject(); + g.close(); + return out.toString(); + } + + private static String callVespaAnalyzeOnnxModel(String jsonInput) throws IOException, InterruptedException { + ProcessBuilder processBuilder = new ProcessBuilder(binary, "--probe-types"); + processBuilder.redirectErrorStream(true); + StringBuilder output = new StringBuilder(); + Process process = processBuilder.start(); // kaster en exception dersom executable ikke finnes + + // Write json array to process stdin + OutputStream os = process.getOutputStream(); + os.write(jsonInput.getBytes(StandardCharsets.UTF_8)); + os.close(); + + // Read output from stdout/stderr + InputStream inputStream = process.getInputStream(); + while (true) { + int b = inputStream.read(); + if (b == -1) break; + output.append((char)b); + } + int returnCode = process.waitFor(); + if (returnCode != 0) { + throw new IllegalArgumentException("Error from '" + binary + "'. Return code: " + returnCode + ". Output:\n" + output); + } + return output.toString(); + } + +} |