summaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java')
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java24
1 files changed, 22 insertions, 2 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java
index 2742dc59fcd..7c89a349d7d 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java
@@ -42,13 +42,16 @@ public class OnnxModelInfo {
private final Map<String, OnnxTypeInfo> inputs;
private final Map<String, OnnxTypeInfo> outputs;
private final Map<String, TensorType> vespaTypes = new HashMap<>();
+ private final Set<String> initializers;
- private OnnxModelInfo(ApplicationPackage app, String path, Map<String, OnnxTypeInfo> inputs, Map<String, OnnxTypeInfo> outputs, String defaultOutput) {
+ private OnnxModelInfo(ApplicationPackage app, String path, Map<String, OnnxTypeInfo> inputs,
+ Map<String, OnnxTypeInfo> outputs, Set<String> initializers, String defaultOutput) {
this.app = app;
this.modelPath = path;
this.inputs = Collections.unmodifiableMap(inputs);
this.outputs = Collections.unmodifiableMap(outputs);
this.defaultOutput = defaultOutput;
+ this.initializers = Set.copyOf(initializers);
}
public String getModelPath() {
@@ -63,6 +66,8 @@ public class OnnxModelInfo {
return outputs.keySet();
}
+ public Set<String> getInitializers() { return initializers; }
+
public String getDefaultOutput() {
return defaultOutput;
}
@@ -208,6 +213,14 @@ public class OnnxModelInfo {
}
g.writeEndArray();
+ g.writeArrayFieldStart("initializers");
+ for (Onnx.TensorProto initializers : model.getGraph().getInitializerList()) {
+ g.writeStartObject();
+ g.writeStringField("name", initializers.getName());
+ g.writeEndObject();
+ }
+ g.writeEndArray();
+
g.writeEndObject();
g.close();
return out.toString();
@@ -218,6 +231,7 @@ public class OnnxModelInfo {
JsonNode root = m.readTree(json);
Map<String, OnnxTypeInfo> inputs = new HashMap<>();
Map<String, OnnxTypeInfo> outputs = new HashMap<>();
+ Set<String> initializers = new HashSet<>();
String defaultOutput = "";
String path = null;
@@ -233,7 +247,13 @@ public class OnnxModelInfo {
if (root.get("outputs").has(0)) {
defaultOutput = root.get("outputs").get(0).get("name").textValue();
}
- return new OnnxModelInfo(app, path, inputs, outputs, defaultOutput);
+ var initializerRoot = root.get("initializers");
+ if (initializerRoot != null) {
+ for (JsonNode initializer : initializerRoot) {
+ initializers.add(initializer.get("name").textValue());
+ }
+ }
+ return new OnnxModelInfo(app, path, inputs, outputs, initializers, defaultOutput);
}
static private void onnxTypeToJson(JsonGenerator g, Onnx.ValueInfoProto valueInfo) throws IOException {