aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java38
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java151
-rwxr-xr-xconfig-model/src/test/cfg/application/onnx_probe/files/create_dynamic_model.py19
-rw-r--r--config-model/src/test/cfg/application/onnx_probe/files/dynamic_model.onnx21
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/ml/OnnxModelProbeTest.java50
5 files changed, 273 insertions, 6 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java
index 88139de7888..2742dc59fcd 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java
@@ -36,13 +36,15 @@ import java.util.stream.Collectors;
*/
public class OnnxModelInfo {
+ private final ApplicationPackage app;
private final String modelPath;
private final String defaultOutput;
private final Map<String, OnnxTypeInfo> inputs;
private final Map<String, OnnxTypeInfo> outputs;
private final Map<String, TensorType> vespaTypes = new HashMap<>();
- private OnnxModelInfo(String path, Map<String, OnnxTypeInfo> inputs, Map<String, OnnxTypeInfo> outputs, String defaultOutput) {
+ private OnnxModelInfo(ApplicationPackage app, String path, Map<String, OnnxTypeInfo> inputs, Map<String, OnnxTypeInfo> outputs, String defaultOutput) {
+ this.app = app;
this.modelPath = path;
this.inputs = Collections.unmodifiableMap(inputs);
this.outputs = Collections.unmodifiableMap(outputs);
@@ -79,7 +81,15 @@ public class OnnxModelInfo {
Set<Long> unboundSizes = new HashSet<>();
Map<String, Long> symbolicSizes = new HashMap<>();
resolveUnknownDimensionSizes(inputTypes, symbolicSizes, unboundSizes);
- return onnxTypeInfo.toVespaTensorType(symbolicSizes, unboundSizes);
+
+ TensorType type = TensorType.empty;
+ if (inputTypes.size() > 0 && onnxTypeInfo.needModelProbe(symbolicSizes)) {
+ type = OnnxModelProbe.probeModel(app, Path.fromString(modelPath), onnxName, inputTypes);
+ }
+ if (type.equals(TensorType.empty)) {
+ type = onnxTypeInfo.toVespaTensorType(symbolicSizes, unboundSizes);
+ }
+ return type;
}
return vespaTypes.computeIfAbsent(onnxName, v -> onnxTypeInfo.toVespaTensorType());
}
@@ -150,7 +160,8 @@ public class OnnxModelInfo {
Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream);
String json = onnxModelToJson(model, path);
storeGeneratedInfo(json, path, app);
- return jsonToModelInfo(json);
+ return jsonToModelInfo(json, app);
+
} catch (IOException e) {
throw new IllegalArgumentException("Unable to parse ONNX model", e);
}
@@ -159,7 +170,7 @@ public class OnnxModelInfo {
static private OnnxModelInfo loadFromGeneratedInfo(Path path, ApplicationPackage app) {
try {
String json = readGeneratedInfo(path, app);
- return jsonToModelInfo(json);
+ return jsonToModelInfo(json, app);
} catch (IOException e) {
throw new IllegalArgumentException("Unable to parse ONNX model", e);
}
@@ -202,7 +213,7 @@ public class OnnxModelInfo {
return out.toString();
}
- static public OnnxModelInfo jsonToModelInfo(String json) throws IOException {
+ static public OnnxModelInfo jsonToModelInfo(String json, ApplicationPackage app) throws IOException {
ObjectMapper m = new ObjectMapper();
JsonNode root = m.readTree(json);
Map<String, OnnxTypeInfo> inputs = new HashMap<>();
@@ -222,7 +233,7 @@ public class OnnxModelInfo {
if (root.get("outputs").has(0)) {
defaultOutput = root.get("outputs").get(0).get("name").textValue();
}
- return new OnnxModelInfo(path, inputs, outputs, defaultOutput);
+ return new OnnxModelInfo(app, path, inputs, outputs, defaultOutput);
}
static private void onnxTypeToJson(JsonGenerator g, Onnx.ValueInfoProto valueInfo) throws IOException {
@@ -353,6 +364,21 @@ public class OnnxModelInfo {
return builder.build();
}
+ boolean needModelProbe(Map<String, Long> symbolicSizes) {
+ for (OnnxDimensionInfo onnxDimension : dimensions) {
+ if (onnxDimension.hasSymbolicName()) {
+ if (symbolicSizes == null)
+ return true;
+ if ( ! symbolicSizes.containsKey(onnxDimension.getSymbolicName())) {
+ return true;
+ }
+ } else if (onnxDimension.getSize() == 0) {
+ return true;
+ }
+ }
+ return false;
+ }
+
@Override
public String toString() {
return "(" + valueType.id() + ")" +
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java
new file mode 100644
index 00000000000..3a77a349af3
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java
@@ -0,0 +1,151 @@
+package com.yahoo.vespa.model.ml;
+
+import com.fasterxml.jackson.core.JsonEncoding;
+import com.fasterxml.jackson.core.JsonFactory;
+import com.fasterxml.jackson.core.JsonGenerator;
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.yahoo.config.application.api.ApplicationFile;
+import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.io.IOUtils;
+import com.yahoo.path.Path;
+import com.yahoo.tensor.TensorType;
+
+import java.io.BufferedReader;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.nio.charset.StandardCharsets;
+import java.util.Map;
+
+/**
+ * Defers to 'vespa-analyze-onnx-model' to determine the output type given
+ * a set of inputs, for cases where symbolic dimension sizes cannot easily
+ * be determined statically.
+ *
+ * @author lesters
+ */
+public class OnnxModelProbe {
+
+ private static final String binary = "vespa-analyze-onnx-model";
+
+ static TensorType probeModel(ApplicationPackage app, Path modelPath, String outputName, Map<String, TensorType> inputTypes) {
+ TensorType outputType = TensorType.empty;
+ String contextKey = createContextKey(outputName, inputTypes);
+
+ try {
+ // Check if output type has already been probed
+ outputType = readProbedOutputType(app, modelPath, contextKey);
+
+ // Otherwise, run vespa-analyze-onnx-model if the model is available
+ if (outputType.equals(TensorType.empty) && app.getFile(modelPath).exists()) {
+ String jsonInput = createJsonInput(app.getFileReference(modelPath).getAbsolutePath(), inputTypes);
+ String jsonOutput = callVespaAnalyzeOnnxModel(jsonInput);
+ outputType = outputTypeFromJson(jsonOutput, outputName);
+ writeProbedOutputType(app, modelPath, contextKey, outputType);
+ }
+
+ } catch (IOException | InterruptedException e) {
+ throw new IllegalArgumentException("Unable to probe ONNX model", e);
+ }
+
+ return outputType;
+ }
+
+ private static String createContextKey(String onnxName, Map<String, TensorType> inputTypes) {
+ StringBuilder key = new StringBuilder().append(onnxName).append(":");
+ inputTypes.entrySet().stream().sorted(Map.Entry.comparingByKey())
+ .forEachOrdered(e -> key.append(e.getKey()).append(":").append(e.getValue()).append(","));
+ return key.substring(0, key.length()-1);
+ }
+
+ private static Path probedOutputTypesPath(Path path) {
+ String fileName = OnnxModelInfo.asValidIdentifier(path.getRelative()) + ".probed_output_types";
+ return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(fileName);
+ }
+
+ static void writeProbedOutputType(ApplicationPackage app, Path modelPath, String output,
+ Map<String, TensorType> inputTypes, TensorType type) throws IOException {
+ writeProbedOutputType(app, modelPath, createContextKey(output, inputTypes), type);
+ }
+
+ private static void writeProbedOutputType(ApplicationPackage app, Path modelPath,
+ String contextKey, TensorType type) throws IOException {
+ String path = app.getFileReference(probedOutputTypesPath(modelPath)).getAbsolutePath();
+ IOUtils.writeFile(path, contextKey + "\t" + type + "\n", true);
+ }
+
+ private static TensorType readProbedOutputType(ApplicationPackage app, Path modelPath,
+ String contextKey) throws IOException {
+ ApplicationFile file = app.getFile(probedOutputTypesPath(modelPath));
+ if ( ! file.exists()) {
+ return TensorType.empty;
+ }
+ try (BufferedReader reader = new BufferedReader(file.createReader())) {
+ String line;
+ while (null != (line = reader.readLine())) {
+ String[] parts = line.split("\t");
+ String key = parts[0];
+ if (key.equals(contextKey)) {
+ return TensorType.fromSpec(parts[1]);
+ }
+ }
+ }
+ return TensorType.empty;
+ }
+
+ private static TensorType outputTypeFromJson(String json, String outputName) throws IOException {
+ ObjectMapper m = new ObjectMapper();
+ JsonNode root = m.readTree(json);
+ if ( ! root.isObject() || ! root.has("outputs")) {
+ return TensorType.empty;
+ }
+ JsonNode outputs = root.get("outputs");
+ if ( ! outputs.has(outputName)) {
+ return TensorType.empty;
+ }
+ return TensorType.fromSpec(outputs.get(outputName).asText());
+ }
+
+ private static String createJsonInput(String modelPath, Map<String, TensorType> inputTypes) throws IOException {
+ ByteArrayOutputStream out = new ByteArrayOutputStream();
+ JsonGenerator g = new JsonFactory().createGenerator(out, JsonEncoding.UTF8);
+ g.writeStartObject();
+ g.writeStringField("model", modelPath);
+ g.writeObjectFieldStart("inputs");
+ for (Map.Entry<String, TensorType> input : inputTypes.entrySet()) {
+ g.writeStringField(input.getKey(), input.getValue().toString());
+ }
+ g.writeEndObject();
+ g.writeEndObject();
+ g.close();
+ return out.toString();
+ }
+
+ private static String callVespaAnalyzeOnnxModel(String jsonInput) throws IOException, InterruptedException {
+ ProcessBuilder processBuilder = new ProcessBuilder(binary, "--probe-types");
+ processBuilder.redirectErrorStream(true);
+ StringBuilder output = new StringBuilder();
+ Process process = processBuilder.start(); // throws an exception if the executable is not found
+
+ // Write json array to process stdin
+ OutputStream os = process.getOutputStream();
+ os.write(jsonInput.getBytes(StandardCharsets.UTF_8));
+ os.close();
+
+ // Read output from stdout/stderr
+ InputStream inputStream = process.getInputStream();
+ while (true) {
+ int b = inputStream.read();
+ if (b == -1) break;
+ output.append((char)b);
+ }
+ int returnCode = process.waitFor();
+ if (returnCode != 0) {
+ throw new IllegalArgumentException("Error from '" + binary + "'. Return code: " + returnCode + ". Output:\n" + output);
+ }
+ return output.toString();
+ }
+
+}
diff --git a/config-model/src/test/cfg/application/onnx_probe/files/create_dynamic_model.py b/config-model/src/test/cfg/application/onnx_probe/files/create_dynamic_model.py
new file mode 100755
index 00000000000..b493e394ee4
--- /dev/null
+++ b/config-model/src/test/cfg/application/onnx_probe/files/create_dynamic_model.py
@@ -0,0 +1,19 @@
+# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+import onnx
+import numpy as np
+from onnx import helper, TensorProto
+
+INPUT_1 = helper.make_tensor_value_info('input1', TensorProto.FLOAT, ["batch", 2])
+INPUT_2 = helper.make_tensor_value_info('input2', TensorProto.FLOAT, ["batch", 2])
+OUTPUT = helper.make_tensor_value_info('out', TensorProto.FLOAT, ["batch", "dim1", "dim2"])
+
+SHAPE = helper.make_tensor('shape', TensorProto.INT64, dims=[3], vals=np.array([1,2,2]).astype(np.int64))
+
+nodes = [
+ helper.make_node('Concat', ['input1', 'input2'], ['concat'], axis=1),
+ helper.make_node('Reshape', ['concat', 'shape'], ['out']),
+]
+graph_def = helper.make_graph(nodes, 'simple_scoring', [INPUT_1, INPUT_2], [OUTPUT], [SHAPE])
+model_def = helper.make_model(graph_def, producer_name='create_dynamic_model.py', opset_imports=[onnx.OperatorSetIdProto(version=12)])
+onnx.save(model_def, 'dynamic_model_2.onnx')
diff --git a/config-model/src/test/cfg/application/onnx_probe/files/dynamic_model.onnx b/config-model/src/test/cfg/application/onnx_probe/files/dynamic_model.onnx
new file mode 100644
index 00000000000..28c600e2a09
--- /dev/null
+++ b/config-model/src/test/cfg/application/onnx_probe/files/dynamic_model.onnx
@@ -0,0 +1,21 @@
+create_dynamic_model_2.py:Ö
+-
+input1
+input2concat"Concat*
+axis 
+
+concat
+shapeout"Reshapesimple_scoring*:BshapeZ
+input1
+
+batch
+Z
+input2
+
+batch
+b&
+out
+
+batch
+dim1
+dim2B \ No newline at end of file
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/OnnxModelProbeTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/OnnxModelProbeTest.java
new file mode 100644
index 00000000000..55e2f6f018c
--- /dev/null
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/OnnxModelProbeTest.java
@@ -0,0 +1,50 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.model.ml;
+
+import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.config.model.application.provider.FilesApplicationPackage;
+import com.yahoo.io.IOUtils;
+import com.yahoo.path.Path;
+import com.yahoo.tensor.TensorType;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+
+public class OnnxModelProbeTest {
+
+ @Test
+ public void testProbedOutputTypes() throws IOException {
+
+ Path appDir = Path.fromString("src/test/cfg/application/onnx_probe");
+ Path storedAppDir = appDir.append("copy");
+ try {
+ FilesApplicationPackage app = FilesApplicationPackage.fromFile(appDir.toFile());
+ Path modelPath = Path.fromString("files/dynamic_model.onnx");
+ String output = "out";
+ Map<String, TensorType> inputTypes = Map.of(
+ "input1", TensorType.fromSpec("tensor<float>(d0[1],d1[2])"),
+ "input2", TensorType.fromSpec("tensor<float>(d0[1],d1[2])"));
+ TensorType expected = TensorType.fromSpec("tensor<float>(d0[1],d1[2],d2[2])");
+
+ // Can't test model probing directly as 'vespa-analyze-onnx-model' is not available
+
+ OnnxModelProbe.writeProbedOutputType(app, modelPath, output, inputTypes, expected);
+
+ // Test loading from generated info
+ storedAppDir.toFile().mkdirs();
+ IOUtils.copyDirectory(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
+ storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
+ app = FilesApplicationPackage.fromFile(storedAppDir.toFile());
+ TensorType outputType = OnnxModelProbe.probeModel(app, modelPath, output, inputTypes);
+ assertEquals(outputType, expected);
+
+ } finally {
+ IOUtils.recursiveDeleteDir(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
+ IOUtils.recursiveDeleteDir(storedAppDir.toFile());
+ }
+ }
+
+}