diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/schema/OnnxModel.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/schema/OnnxModel.java | 7 |
1 files changed, 6 insertions, 1 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java index 272b668b5fb..90a27d1f036 100644 --- a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java @@ -1,15 +1,17 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.schema; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.TensorType; import com.yahoo.vespa.model.ml.OnnxModelInfo; -import com.yahoo.searchlib.rankingexpression.Reference; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.Set; /** * A global ONNX model distributed using file distribution, similar to ranking constants. @@ -21,6 +23,7 @@ public class OnnxModel extends DistributableResource { private OnnxModelInfo modelInfo = null; private final Map<String, String> inputMap = new HashMap<>(); private final Map<String, String> outputMap = new HashMap<>(); + private final Set<String> initializers = new HashSet<>(); private String statelessExecutionMode = null; private Integer statelessInterOpThreads = null; @@ -101,11 +104,13 @@ public class OnnxModel extends DistributableResource { for (String onnxName : modelInfo.getOutputs()) { addOutputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false); } + initializers.addAll(modelInfo.getInitializers()); this.modelInfo = modelInfo; } public Map<String, String> getInputMap() { return Collections.unmodifiableMap(inputMap); } public Map<String, String> getOutputMap() { return Collections.unmodifiableMap(outputMap); } + public Set<String> getInitializers() { return Set.copyOf(initializers); } public String getDefaultOutput() { return modelInfo != null ? modelInfo.getDefaultOutput() : ""; |