summaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/main/java/com/yahoo/schema/OnnxModel.java')
-rw-r--r--config-model/src/main/java/com/yahoo/schema/OnnxModel.java7
1 files changed, 6 insertions, 1 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
index 272b668b5fb..90a27d1f036 100644
--- a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
+++ b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
@@ -1,15 +1,17 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.schema;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.model.ml.OnnxModelInfo;
-import com.yahoo.searchlib.rankingexpression.Reference;
import java.util.Collections;
import java.util.HashMap;
+import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
+import java.util.Set;
/**
* A global ONNX model distributed using file distribution, similar to ranking constants.
@@ -21,6 +23,7 @@ public class OnnxModel extends DistributableResource {
private OnnxModelInfo modelInfo = null;
private final Map<String, String> inputMap = new HashMap<>();
private final Map<String, String> outputMap = new HashMap<>();
+ private final Set<String> initializers = new HashSet<>();
private String statelessExecutionMode = null;
private Integer statelessInterOpThreads = null;
@@ -101,11 +104,13 @@ public class OnnxModel extends DistributableResource {
for (String onnxName : modelInfo.getOutputs()) {
addOutputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false);
}
+ initializers.addAll(modelInfo.getInitializers());
this.modelInfo = modelInfo;
}
public Map<String, String> getInputMap() { return Collections.unmodifiableMap(inputMap); }
public Map<String, String> getOutputMap() { return Collections.unmodifiableMap(outputMap); }
+ public Set<String> getInitializers() { return Set.copyOf(initializers); }
public String getDefaultOutput() {
return modelInfo != null ? modelInfo.getDefaultOutput() : "";