diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java | 24 |
1 files changed, 22 insertions, 2 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java index 2742dc59fcd..7c89a349d7d 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java @@ -42,13 +42,16 @@ public class OnnxModelInfo { private final Map<String, OnnxTypeInfo> inputs; private final Map<String, OnnxTypeInfo> outputs; private final Map<String, TensorType> vespaTypes = new HashMap<>(); + private final Set<String> initializers; - private OnnxModelInfo(ApplicationPackage app, String path, Map<String, OnnxTypeInfo> inputs, Map<String, OnnxTypeInfo> outputs, String defaultOutput) { + private OnnxModelInfo(ApplicationPackage app, String path, Map<String, OnnxTypeInfo> inputs, + Map<String, OnnxTypeInfo> outputs, Set<String> initializers, String defaultOutput) { this.app = app; this.modelPath = path; this.inputs = Collections.unmodifiableMap(inputs); this.outputs = Collections.unmodifiableMap(outputs); this.defaultOutput = defaultOutput; + this.initializers = Set.copyOf(initializers); } public String getModelPath() { @@ -63,6 +66,8 @@ public class OnnxModelInfo { return outputs.keySet(); } + public Set<String> getInitializers() { return initializers; } + public String getDefaultOutput() { return defaultOutput; } @@ -208,6 +213,14 @@ public class OnnxModelInfo { } g.writeEndArray(); + g.writeArrayFieldStart("initializers"); + for (Onnx.TensorProto initializers : model.getGraph().getInitializerList()) { + g.writeStartObject(); + g.writeStringField("name", initializers.getName()); + g.writeEndObject(); + } + g.writeEndArray(); + g.writeEndObject(); g.close(); return out.toString(); @@ -218,6 +231,7 @@ public class OnnxModelInfo { JsonNode root = m.readTree(json); Map<String, OnnxTypeInfo> inputs = new HashMap<>(); Map<String, OnnxTypeInfo> outputs = new HashMap<>(); + Set<String> initializers = new HashSet<>(); String defaultOutput = ""; String path = null; @@ -233,7 +247,13 @@ public class OnnxModelInfo { if (root.get("outputs").has(0)) { defaultOutput = root.get("outputs").get(0).get("name").textValue(); } - return new OnnxModelInfo(app, path, inputs, outputs, defaultOutput); + var initializerRoot = root.get("initializers"); + if (initializerRoot != null) { + for (JsonNode initializer : initializerRoot) { + initializers.add(initializer.get("name").textValue()); + } + } + return new OnnxModelInfo(app, path, inputs, outputs, initializers, defaultOutput); } static private void onnxTypeToJson(JsonGenerator g, Onnx.ValueInfoProto valueInfo) throws IOException { |