diff options
5 files changed, 273 insertions, 6 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java index 88139de7888..2742dc59fcd 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java @@ -36,13 +36,15 @@ import java.util.stream.Collectors; */ public class OnnxModelInfo { + private final ApplicationPackage app; private final String modelPath; private final String defaultOutput; private final Map<String, OnnxTypeInfo> inputs; private final Map<String, OnnxTypeInfo> outputs; private final Map<String, TensorType> vespaTypes = new HashMap<>(); - private OnnxModelInfo(String path, Map<String, OnnxTypeInfo> inputs, Map<String, OnnxTypeInfo> outputs, String defaultOutput) { + private OnnxModelInfo(ApplicationPackage app, String path, Map<String, OnnxTypeInfo> inputs, Map<String, OnnxTypeInfo> outputs, String defaultOutput) { + this.app = app; this.modelPath = path; this.inputs = Collections.unmodifiableMap(inputs); this.outputs = Collections.unmodifiableMap(outputs); @@ -79,7 +81,15 @@ public class OnnxModelInfo { Set<Long> unboundSizes = new HashSet<>(); Map<String, Long> symbolicSizes = new HashMap<>(); resolveUnknownDimensionSizes(inputTypes, symbolicSizes, unboundSizes); - return onnxTypeInfo.toVespaTensorType(symbolicSizes, unboundSizes); + + TensorType type = TensorType.empty; + if (inputTypes.size() > 0 && onnxTypeInfo.needModelProbe(symbolicSizes)) { + type = OnnxModelProbe.probeModel(app, Path.fromString(modelPath), onnxName, inputTypes); + } + if (type.equals(TensorType.empty)) { + type = onnxTypeInfo.toVespaTensorType(symbolicSizes, unboundSizes); + } + return type; } return vespaTypes.computeIfAbsent(onnxName, v -> onnxTypeInfo.toVespaTensorType()); } @@ -150,7 +160,8 @@ public class OnnxModelInfo { Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream); String json = onnxModelToJson(model, 
path); storeGeneratedInfo(json, path, app); - return jsonToModelInfo(json); + return jsonToModelInfo(json, app); + } catch (IOException e) { throw new IllegalArgumentException("Unable to parse ONNX model", e); } @@ -159,7 +170,7 @@ public class OnnxModelInfo { static private OnnxModelInfo loadFromGeneratedInfo(Path path, ApplicationPackage app) { try { String json = readGeneratedInfo(path, app); - return jsonToModelInfo(json); + return jsonToModelInfo(json, app); } catch (IOException e) { throw new IllegalArgumentException("Unable to parse ONNX model", e); } @@ -202,7 +213,7 @@ public class OnnxModelInfo { return out.toString(); } - static public OnnxModelInfo jsonToModelInfo(String json) throws IOException { + static public OnnxModelInfo jsonToModelInfo(String json, ApplicationPackage app) throws IOException { ObjectMapper m = new ObjectMapper(); JsonNode root = m.readTree(json); Map<String, OnnxTypeInfo> inputs = new HashMap<>(); @@ -222,7 +233,7 @@ public class OnnxModelInfo { if (root.get("outputs").has(0)) { defaultOutput = root.get("outputs").get(0).get("name").textValue(); } - return new OnnxModelInfo(path, inputs, outputs, defaultOutput); + return new OnnxModelInfo(app, path, inputs, outputs, defaultOutput); } static private void onnxTypeToJson(JsonGenerator g, Onnx.ValueInfoProto valueInfo) throws IOException { @@ -353,6 +364,21 @@ public class OnnxModelInfo { return builder.build(); } + boolean needModelProbe(Map<String, Long> symbolicSizes) { + for (OnnxDimensionInfo onnxDimension : dimensions) { + if (onnxDimension.hasSymbolicName()) { + if (symbolicSizes == null) + return true; + if ( ! 
symbolicSizes.containsKey(onnxDimension.getSymbolicName())) { + return true; + } + } else if (onnxDimension.getSize() == 0) { + return true; + } + } + return false; + } + @Override public String toString() { return "(" + valueType.id() + ")" + diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java new file mode 100644 index 00000000000..3a77a349af3 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java @@ -0,0 +1,151 @@ +package com.yahoo.vespa.model.ml; + +import com.fasterxml.jackson.core.JsonEncoding; +import com.fasterxml.jackson.core.JsonFactory; +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.yahoo.config.application.api.ApplicationFile; +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.io.IOUtils; +import com.yahoo.path.Path; +import com.yahoo.tensor.TensorType; + +import java.io.BufferedReader; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.charset.StandardCharsets; +import java.util.Map; + +/** + * Defers to 'vespa-analyze-onnx-model' to determine the output type given + * a set of inputs. For situations with symbolic dimension sizes that can't + * easily be determined. 
+ * + * @author lesters + */ +public class OnnxModelProbe { + + private static final String binary = "vespa-analyze-onnx-model"; + + static TensorType probeModel(ApplicationPackage app, Path modelPath, String outputName, Map<String, TensorType> inputTypes) { + TensorType outputType = TensorType.empty; + String contextKey = createContextKey(outputName, inputTypes); + + try { + // Check if output type has already been probed + outputType = readProbedOutputType(app, modelPath, contextKey); + + // Otherwise, run vespa-analyze-onnx-model if the model is available + if (outputType.equals(TensorType.empty) && app.getFile(modelPath).exists()) { + String jsonInput = createJsonInput(app.getFileReference(modelPath).getAbsolutePath(), inputTypes); + String jsonOutput = callVespaAnalyzeOnnxModel(jsonInput); + outputType = outputTypeFromJson(jsonOutput, outputName); + writeProbedOutputType(app, modelPath, contextKey, outputType); + } + + } catch (IOException | InterruptedException e) { + throw new IllegalArgumentException("Unable to probe ONNX model", e); + } + + return outputType; + } + + private static String createContextKey(String onnxName, Map<String, TensorType> inputTypes) { + StringBuilder key = new StringBuilder().append(onnxName).append(":"); + inputTypes.entrySet().stream().sorted(Map.Entry.comparingByKey()) + .forEachOrdered(e -> key.append(e.getKey()).append(":").append(e.getValue()).append(",")); + return key.substring(0, key.length()-1); + } + + private static Path probedOutputTypesPath(Path path) { + String fileName = OnnxModelInfo.asValidIdentifier(path.getRelative()) + ".probed_output_types"; + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(fileName); + } + + static void writeProbedOutputType(ApplicationPackage app, Path modelPath, String output, + Map<String, TensorType> inputTypes, TensorType type) throws IOException { + writeProbedOutputType(app, modelPath, createContextKey(output, inputTypes), type); + } + + private static void 
writeProbedOutputType(ApplicationPackage app, Path modelPath, + String contextKey, TensorType type) throws IOException { + String path = app.getFileReference(probedOutputTypesPath(modelPath)).getAbsolutePath(); + IOUtils.writeFile(path, contextKey + "\t" + type + "\n", true); + } + + private static TensorType readProbedOutputType(ApplicationPackage app, Path modelPath, + String contextKey) throws IOException { + ApplicationFile file = app.getFile(probedOutputTypesPath(modelPath)); + if ( ! file.exists()) { + return TensorType.empty; + } + try (BufferedReader reader = new BufferedReader(file.createReader())) { + String line; + while (null != (line = reader.readLine())) { + String[] parts = line.split("\t"); + String key = parts[0]; + if (key.equals(contextKey)) { + return TensorType.fromSpec(parts[1]); + } + } + } + return TensorType.empty; + } + + private static TensorType outputTypeFromJson(String json, String outputName) throws IOException { + ObjectMapper m = new ObjectMapper(); + JsonNode root = m.readTree(json); + if ( ! root.isObject() || ! root.has("outputs")) { + return TensorType.empty; + } + JsonNode outputs = root.get("outputs"); + if ( ! 
outputs.has(outputName)) { + return TensorType.empty; + } + return TensorType.fromSpec(outputs.get(outputName).asText()); + } + + private static String createJsonInput(String modelPath, Map<String, TensorType> inputTypes) throws IOException { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + JsonGenerator g = new JsonFactory().createGenerator(out, JsonEncoding.UTF8); + g.writeStartObject(); + g.writeStringField("model", modelPath); + g.writeObjectFieldStart("inputs"); + for (Map.Entry<String, TensorType> input : inputTypes.entrySet()) { + g.writeStringField(input.getKey(), input.getValue().toString()); + } + g.writeEndObject(); + g.writeEndObject(); + g.close(); + return out.toString(); + } + + private static String callVespaAnalyzeOnnxModel(String jsonInput) throws IOException, InterruptedException { + ProcessBuilder processBuilder = new ProcessBuilder(binary, "--probe-types"); + processBuilder.redirectErrorStream(true); + StringBuilder output = new StringBuilder(); + Process process = processBuilder.start(); // throws an exception if the executable does not exist + + // Write the JSON input object to the process' stdin + OutputStream os = process.getOutputStream(); + os.write(jsonInput.getBytes(StandardCharsets.UTF_8)); + os.close(); + + // Read output from stdout/stderr + InputStream inputStream = process.getInputStream(); + while (true) { + int b = inputStream.read(); + if (b == -1) break; + output.append((char)b); + } + int returnCode = process.waitFor(); + if (returnCode != 0) { + throw new IllegalArgumentException("Error from '" + binary + "'. Return code: " + returnCode + ". 
Output:\n" + output); + } + return output.toString(); + } + +} diff --git a/config-model/src/test/cfg/application/onnx_probe/files/create_dynamic_model.py b/config-model/src/test/cfg/application/onnx_probe/files/create_dynamic_model.py new file mode 100755 index 00000000000..b493e394ee4 --- /dev/null +++ b/config-model/src/test/cfg/application/onnx_probe/files/create_dynamic_model.py @@ -0,0 +1,19 @@ +# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +import onnx +import numpy as np +from onnx import helper, TensorProto + +INPUT_1 = helper.make_tensor_value_info('input1', TensorProto.FLOAT, ["batch", 2]) +INPUT_2 = helper.make_tensor_value_info('input2', TensorProto.FLOAT, ["batch", 2]) +OUTPUT = helper.make_tensor_value_info('out', TensorProto.FLOAT, ["batch", "dim1", "dim2"]) + +SHAPE = helper.make_tensor('shape', TensorProto.INT64, dims=[3], vals=np.array([1,2,2]).astype(np.int64)) + +nodes = [ + helper.make_node('Concat', ['input1', 'input2'], ['concat'], axis=1), + helper.make_node('Reshape', ['concat', 'shape'], ['out']), +] +graph_def = helper.make_graph(nodes, 'simple_scoring', [INPUT_1, INPUT_2], [OUTPUT], [SHAPE]) +model_def = helper.make_model(graph_def, producer_name='create_dynamic_model.py', opset_imports=[onnx.OperatorSetIdProto(version=12)]) +onnx.save(model_def, 'dynamic_model_2.onnx') diff --git a/config-model/src/test/cfg/application/onnx_probe/files/dynamic_model.onnx b/config-model/src/test/cfg/application/onnx_probe/files/dynamic_model.onnx new file mode 100644 index 00000000000..28c600e2a09 --- /dev/null +++ b/config-model/src/test/cfg/application/onnx_probe/files/dynamic_model.onnx @@ -0,0 +1,21 @@ +create_dynamic_model_2.py:Ö +- +input1 +input2concat"Concat* +axis + +concat +shapeout"Reshapesimple_scoring*:BshapeZ +input1 +
+batch +Z +input2 +
+batch +b& +out + +batch +dim1 +dim2B
\ No newline at end of file diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/OnnxModelProbeTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/OnnxModelProbeTest.java new file mode 100644 index 00000000000..55e2f6f018c --- /dev/null +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/OnnxModelProbeTest.java @@ -0,0 +1,50 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.model.ml; + +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.model.application.provider.FilesApplicationPackage; +import com.yahoo.io.IOUtils; +import com.yahoo.path.Path; +import com.yahoo.tensor.TensorType; +import org.junit.Test; + +import java.io.IOException; +import java.util.Map; + +import static org.junit.Assert.assertEquals; + +public class OnnxModelProbeTest { + + @Test + public void testProbedOutputTypes() throws IOException { + + Path appDir = Path.fromString("src/test/cfg/application/onnx_probe"); + Path storedAppDir = appDir.append("copy"); + try { + FilesApplicationPackage app = FilesApplicationPackage.fromFile(appDir.toFile()); + Path modelPath = Path.fromString("files/dynamic_model.onnx"); + String output = "out"; + Map<String, TensorType> inputTypes = Map.of( + "input1", TensorType.fromSpec("tensor<float>(d0[1],d1[2])"), + "input2", TensorType.fromSpec("tensor<float>(d0[1],d1[2])")); + TensorType expected = TensorType.fromSpec("tensor<float>(d0[1],d1[2],d2[2])"); + + // Can't test model probing directly as 'vespa-analyze-onnx-model' is not available + + OnnxModelProbe.writeProbedOutputType(app, modelPath, output, inputTypes, expected); + + // Test loading from generated info + storedAppDir.toFile().mkdirs(); + IOUtils.copyDirectory(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), + storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + app = 
FilesApplicationPackage.fromFile(storedAppDir.toFile()); + TensorType outputType = OnnxModelProbe.probeModel(app, modelPath, output, inputTypes); + assertEquals(outputType, expected); + + } finally { + IOUtils.recursiveDeleteDir(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + IOUtils.recursiveDeleteDir(storedAppDir.toFile()); + } + } + +} |