aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/vespa/model/ml
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2020-10-29 15:31:14 +0100
committerLester Solbakken <lesters@oath.com>2020-10-29 15:31:14 +0100
commit2425b3562e5a84e6caf44228712cfade4c8583a2 (patch)
tree6ea1dff3bd3ba0ea4accb91d5e30447eddaf030a /config-model/src/main/java/com/yahoo/vespa/model/ml
parentfbd8ca020a6d97882c554585d911889d1b9f69ea (diff)
Store generated model info for ZK
Diffstat (limited to 'config-model/src/main/java/com/yahoo/vespa/model/ml')
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java389
1 files changed, 389 insertions, 0 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java
new file mode 100644
index 00000000000..7526a8a8595
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java
@@ -0,0 +1,389 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.model.ml;
+
+import com.fasterxml.jackson.core.JsonEncoding;
+import com.fasterxml.jackson.core.JsonFactory;
+import com.fasterxml.jackson.core.JsonGenerator;
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.yahoo.config.application.api.ApplicationFile;
+import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.io.IOUtils;
+import com.yahoo.path.Path;
+import com.yahoo.tensor.TensorType;
+import onnx.Onnx;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * Model information (input and output types) for an ONNX model.
+ * This encapsulates the difference between reading ONNX model information
+ * - from a file application package, where we can read the ONNX model directly
+ * - from a ZK application package, where the file is unavailable and models are read from
+ * generated files stored in file distribution or ZooKeeper.
+ *
+ * @author lesters
+ */
+public class OnnxModelInfo {
+
+ private final String defaultOutput;
+ private final Map<String, OnnxTypeInfo> inputs;
+ private final Map<String, OnnxTypeInfo> outputs;
+ private final Map<String, TensorType> vespaTypes = new HashMap<>();
+
+ private OnnxModelInfo(Map<String, OnnxTypeInfo> inputs, Map<String, OnnxTypeInfo> outputs, String defaultOutput) {
+ this.inputs = Collections.unmodifiableMap(inputs);
+ this.outputs = Collections.unmodifiableMap(outputs);
+ this.defaultOutput = defaultOutput;
+ }
+
+ public Set<String> getInputs() {
+ return inputs.keySet();
+ }
+
+ public Set<String> getOutputs() {
+ return outputs.keySet();
+ }
+
+ public String getDefaultOutput() {
+ return defaultOutput;
+ }
+
+ /**
+ * Return the tensor type for an ONNX model for the given context.
+ * An ONNX model can have dynamic/symbolic dimension sizes. If so, the output
+ * type depends on the input types for the given context (rank profile).
+ */
+ public TensorType getTensorType(String onnxName, Map<String, TensorType> inputTypes) {
+ OnnxTypeInfo onnxTypeInfo = outputs.get(onnxName);
+ if (onnxTypeInfo == null) {
+ throw new IllegalArgumentException("Could not find type for output '" + onnxName + "'");
+ }
+ if (onnxTypeInfo.containsUnknownDimensionSizes()) {
+ Set<Long> unboundSizes = new HashSet<>();
+ Map<String, Long> symbolicSizes = new HashMap<>();
+ resolveUnknownDimensionSizes(inputTypes, symbolicSizes, unboundSizes);
+ return onnxTypeInfo.toVespaTensorType(symbolicSizes, unboundSizes);
+ }
+ return vespaTypes.computeIfAbsent(onnxName, v -> onnxTypeInfo.toVespaTensorType());
+ }
+
+ private void resolveUnknownDimensionSizes(Map<String, TensorType> inputTypes,
+ Map<String, Long> symbolicSizes,
+ Set<Long> unboundSizes)
+ {
+ for (Map.Entry<String, OnnxTypeInfo> input : inputs.entrySet()) {
+ String onnxName = input.getKey();
+ OnnxTypeInfo onnxType = input.getValue();
+ TensorType vespaType = inputTypes.get(onnxName);
+ if (vespaType == null || vespaType.dimensions().size() != onnxType.dimensions().size()) {
+ continue;
+ }
+
+ for (int i = 0; i < vespaType.dimensions().size(); ++i) {
+ if (vespaType.dimensions().get(i).size().isEmpty()) {
+ continue;
+ }
+ Long size = vespaType.dimensions().get(i).size().get();
+
+ // Handle dimensions with size -1 - typically batch dimensions
+ if (onnxType.dimensions().get(i).getSize() == -1) {
+ unboundSizes.add(size);
+ if (unboundSizes.size() > 1) {
+ throw new IllegalArgumentException("Found conflicting sizes for unbound dimension " +
+ "for type '" + onnxType + "'");
+ }
+
+ // Handle dimensions with symbolic names
+ } else if (onnxType.dimensions().get(i).hasSymbolicName()) {
+ String symbolicName = onnxType.dimensions().get(i).getSymbolicName();
+ if (symbolicSizes.containsKey(symbolicName) && ! symbolicSizes.get(symbolicName).equals(size)) {
+ throw new IllegalArgumentException("Found conflicting sizes for symbolic dimension '" +
+ symbolicName + "' for input '" + onnxName + "'");
+ }
+ symbolicSizes.put(symbolicName, size);
+ }
+ }
+ }
+ }
+
+ static public OnnxModelInfo load(String path, ApplicationPackage app) {
+ Path pathInApplicationPackage = Path.fromString(path);
+ if (app.getFile(pathInApplicationPackage).exists()) {
+ return loadFromFile(pathInApplicationPackage, app);
+ }
+ if (app.getFile(generatedModelInfoPath(pathInApplicationPackage)).exists()) {
+ return loadFromGeneratedInfo(pathInApplicationPackage, app);
+ }
+ throw new IllegalArgumentException("Unable to find ONNX model file or generated ONNX info file");
+ }
+
+ static public boolean modelExists(String path, ApplicationPackage app) {
+ Path pathInApplicationPackage = Path.fromString(path);
+ if (app.getFile(pathInApplicationPackage).exists()) {
+ return true;
+ }
+ if (app.getFile(generatedModelInfoPath(Path.fromString(path))).exists()) {
+ return true;
+ }
+ return false;
+ }
+
+ static private OnnxModelInfo loadFromFile(Path path, ApplicationPackage app) {
+ try (InputStream inputStream = app.getFile(path).createInputStream()) {
+ Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream);
+ String json = onnxModelToJson(model);
+ storeGeneratedInfo(json, path, app);
+ return jsonToModelInfo(json);
+ } catch (IOException e) {
+ throw new IllegalArgumentException("Unable to parse ONNX model", e);
+ }
+ }
+
+ static private OnnxModelInfo loadFromGeneratedInfo(Path path, ApplicationPackage app) {
+ try {
+ String json = readGeneratedInfo(path, app);
+ return jsonToModelInfo(json);
+ } catch (IOException e) {
+ throw new IllegalArgumentException("Unable to parse ONNX model", e);
+ }
+ }
+
+ static private String readGeneratedInfo(Path path, ApplicationPackage app) throws IOException {
+ ApplicationFile file = app.getFile(generatedModelInfoPath(path));
+ return IOUtils.readAll(file.createReader());
+ }
+
+ static private void storeGeneratedInfo(String json, Path path, ApplicationPackage app) throws IOException {
+ IOUtils.writeFile(app.getFileReference(generatedModelInfoPath(path)), json, false);
+ }
+
+ static private Path generatedModelInfoPath(Path path) {
+ String fileName = asValidIdentifier(path.getRelative()) + ".modelinfo.json";
+ return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(fileName);
+ }
+
+ static private String onnxModelToJson(Onnx.ModelProto model) throws IOException {
+ ByteArrayOutputStream out = new ByteArrayOutputStream();
+ JsonGenerator g = new JsonFactory().createGenerator(out, JsonEncoding.UTF8);
+ g.writeStartObject();
+
+ g.writeArrayFieldStart("inputs");
+ for (Onnx.ValueInfoProto valueInfo : model.getGraph().getInputList()) {
+ onnxTypeToJson(g, valueInfo);
+ }
+ g.writeEndArray();
+
+ g.writeArrayFieldStart("outputs");
+ for (Onnx.ValueInfoProto valueInfo : model.getGraph().getOutputList()) {
+ onnxTypeToJson(g, valueInfo);
+ }
+ g.writeEndArray();
+
+ g.writeEndObject();
+ g.close();
+ return out.toString();
+ }
+
+ static public OnnxModelInfo jsonToModelInfo(String json) throws IOException {
+ ObjectMapper m = new ObjectMapper();
+ JsonNode root = m.readTree(json);
+ Map<String, OnnxTypeInfo> inputs = new HashMap<>();
+ Map<String, OnnxTypeInfo> outputs = new HashMap<>();
+ String defaultOutput = "";
+
+ for (JsonNode input : root.get("inputs")) {
+ inputs.put(input.get("name").textValue(), jsonToTypeInfo(input));
+ }
+ for (JsonNode output : root.get("outputs")) {
+ outputs.put(output.get("name").textValue(), jsonToTypeInfo(output));
+ }
+ if (root.get("outputs").has(0)) {
+ defaultOutput = root.get("outputs").get(0).get("name").textValue();
+ }
+ return new OnnxModelInfo(inputs, outputs, defaultOutput);
+ }
+
+ static private void onnxTypeToJson(JsonGenerator g, Onnx.ValueInfoProto valueInfo) throws IOException {
+ g.writeStartObject();
+ g.writeStringField("name", valueInfo.getName());
+ g.writeStringField("type", onnxValueTypeToString(valueInfo.getType().getTensorType().getElemType()));
+ g.writeArrayFieldStart("dim");
+ for (Onnx.TensorShapeProto.Dimension dim : valueInfo.getType().getTensorType().getShape().getDimList()) {
+ g.writeStartObject();
+ if (dim.hasDimParam()) {
+ g.writeStringField("type", "param");
+ g.writeStringField("size", dim.getDimParam());
+ } else {
+ g.writeStringField("type", "value");
+ g.writeNumberField("size", dim.getDimValue());
+ }
+ g.writeEndObject();
+ }
+ g.writeEndArray();
+ g.writeEndObject();
+ }
+
+ static private OnnxTypeInfo jsonToTypeInfo(JsonNode node) {
+ TensorType.Value valueType = stringToValueType(node.get("type").textValue());
+ OnnxTypeInfo type = new OnnxTypeInfo(valueType);
+ for (JsonNode dim : node.get("dim")) {
+ if (dim.get("type").textValue().equals("param")) {
+ type.addDimension(dim.get("size").textValue());
+ } else {
+ type.addDimension(dim.get("size").longValue());
+ }
+ }
+ return type;
+ }
+
+ private static String onnxValueTypeToString(Onnx.TensorProto.DataType dataType) {
+ switch (dataType) {
+ case FLOAT: return "float";
+ case DOUBLE: return "double";
+ // Imperfect conversion, for now:
+ case BOOL: return "float";
+ case INT8: return "float";
+ case INT16: return "float";
+ case INT32: return "float";
+ case INT64: return "float";
+ case UINT8: return "float";
+ case UINT16: return "float";
+ case UINT32: return "float";
+ case UINT64: return "float";
+ default:
+ throw new IllegalArgumentException("A ONNX tensor with data type " + dataType +
+ " cannot be converted to a Vespa tensor type");
+ }
+ }
+
+ private static TensorType.Value stringToValueType(String type) {
+ switch (type) {
+ case "float": return TensorType.Value.FLOAT;
+ case "double": return TensorType.Value.DOUBLE;
+ default:
+ throw new IllegalArgumentException("Unknown tensor value type: " + type);
+ }
+ }
+
+ public static String asValidIdentifier(String str) {
+ return str.replaceAll("[^\\w\\d\\$@_]", "_");
+ }
+
+
+ private static class OnnxTypeInfo {
+ private final TensorType.Value valueType;
+ private final List<OnnxDimensionInfo> dimensions = new ArrayList<>();
+
+ OnnxTypeInfo(TensorType.Value valueType) {
+ this.valueType = valueType;
+ }
+
+ void addDimension(long value) {
+ dimensions.add(new OnnxDimensionInfo(value));
+ }
+
+ void addDimension(String param) {
+ dimensions.add(new OnnxDimensionInfo(param));
+ }
+
+ boolean containsUnknownDimensionSizes() {
+ return dimensions.stream().anyMatch(OnnxDimensionInfo::unknownDimensionSize);
+ }
+
+ TensorType.Value valueType() {
+ return valueType;
+ }
+
+ List<OnnxDimensionInfo> dimensions() {
+ return dimensions;
+ }
+
+ TensorType toVespaTensorType() {
+ return toVespaTensorType(null, null);
+ }
+
+ TensorType toVespaTensorType(Map<String, Long> symbolicSizes, Set<Long> unboundSizes) {
+ String dimensionPrefix = "d"; // standard naming convention: d0, d1, ...
+ TensorType.Builder builder = new TensorType.Builder(valueType);
+ for (int i = 0; i < dimensions.size(); ++ i) {
+ String dimensionName = dimensionPrefix + i;
+ OnnxDimensionInfo onnxDimension = dimensions.get(i);
+ long onnxDimensionSize = onnxDimension.getSize();
+ if (onnxDimension.hasSymbolicName() && symbolicSizes != null && symbolicSizes.containsKey(onnxDimension.getSymbolicName())) {
+ onnxDimensionSize = symbolicSizes.get(onnxDimension.getSymbolicName());
+ }
+ if (onnxDimensionSize == 0 && symbolicSizes != null) {
+ // This is for the case where all symbolic dimensions have
+ // different names, but can be resolved to a single dimension size.
+ Set<Long> unknownSizes = new HashSet<>(symbolicSizes.values());
+ if (unknownSizes.size() == 1) {
+ onnxDimensionSize = unknownSizes.iterator().next();
+ }
+ }
+ if (onnxDimensionSize < 0 && unboundSizes != null && unboundSizes.size() > 0) {
+ onnxDimensionSize = unboundSizes.iterator().next();
+ }
+ if (onnxDimensionSize <= 0) {
+ return TensorType.empty; // Unable to determine type - probably out of context
+ }
+ builder.indexed(dimensionName, onnxDimensionSize);
+ }
+ return builder.build();
+ }
+
+ @Override
+ public String toString() {
+ return "(" + valueType.id() + ")" +
+ "[" + dimensions.stream().map(OnnxDimensionInfo::toString).collect(Collectors.joining(",")) + "]";
+ }
+
+ }
+
+ private static class OnnxDimensionInfo {
+ private final long size;
+ private final String symbolicName;
+
+ OnnxDimensionInfo(long size) {
+ this.size = size;
+ this.symbolicName = null;
+ }
+
+ OnnxDimensionInfo(String symbolicName) {
+ this.size = 0;
+ this.symbolicName = symbolicName;
+ }
+
+ long getSize() {
+ return size;
+ }
+
+ String getSymbolicName() {
+ return symbolicName;
+ }
+
+ boolean hasSymbolicName() {
+ return symbolicName != null;
+ }
+
+ boolean unknownDimensionSize() {
+ return hasSymbolicName() || size <= 0;
+ }
+
+ @Override
+ public String toString() {
+ return hasSymbolicName() ? "\"" + symbolicName + "\"" : Long.toString(size);
+ }
+ }
+
+}