summaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2022-05-19 12:03:06 +0200
committerJon Bratseth <bratseth@gmail.com>2022-05-19 12:03:06 +0200
commit5c24dc5c9642a8d9ed70aee4c950fd0678a1ebec (patch)
treebd9b74bf00c832456f0b83c1b2cd7010be387d68 /config-model/src/main/java/com/yahoo/schema/OnnxModel.java
parentf17c4fe7de4c55f5c4ee61897eab8c2f588d8405 (diff)
Rename the 'searchdefinition' package to 'schema'
Diffstat (limited to 'config-model/src/main/java/com/yahoo/schema/OnnxModel.java')
-rw-r--r--config-model/src/main/java/com/yahoo/schema/OnnxModel.java120
1 files changed, 120 insertions, 0 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
new file mode 100644
index 00000000000..26a0b3e595d
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
@@ -0,0 +1,120 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.schema;
+
+import com.yahoo.tensor.TensorType;
+import com.yahoo.vespa.model.ml.OnnxModelInfo;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+
+/**
+ * A global ONNX model distributed using file distribution, similar to ranking constants.
+ *
+ * @author lesters
+ */
+public class OnnxModel extends DistributableResource {
+
+ private OnnxModelInfo modelInfo = null;
+ private final Map<String, String> inputMap = new HashMap<>();
+ private final Map<String, String> outputMap = new HashMap<>();
+
+ private String statelessExecutionMode = null;
+ private Integer statelessInterOpThreads = null;
+ private Integer statelessIntraOpThreads = null;
+
+ public OnnxModel(String name) {
+ super(name);
+ }
+
+ public OnnxModel(String name, String fileName) {
+ super(name, fileName);
+ validate();
+ }
+
+ @Override
+ public void setUri(String uri) {
+ throw new IllegalArgumentException("URI for ONNX models are not currently supported");
+ }
+
+ public void addInputNameMapping(String onnxName, String vespaName) {
+ addInputNameMapping(onnxName, vespaName, true);
+ }
+
+ public void addInputNameMapping(String onnxName, String vespaName, boolean overwrite) {
+ Objects.requireNonNull(onnxName, "Onnx name cannot be null");
+ Objects.requireNonNull(vespaName, "Vespa name cannot be null");
+ if (overwrite || ! inputMap.containsKey(onnxName)) {
+ inputMap.put(onnxName, vespaName);
+ }
+ }
+
+ public void addOutputNameMapping(String onnxName, String vespaName) {
+ addOutputNameMapping(onnxName, vespaName, true);
+ }
+
+ public void addOutputNameMapping(String onnxName, String vespaName, boolean overwrite) {
+ Objects.requireNonNull(onnxName, "Onnx name cannot be null");
+ Objects.requireNonNull(vespaName, "Vespa name cannot be null");
+ if (overwrite || ! outputMap.containsKey(onnxName)) {
+ outputMap.put(onnxName, vespaName);
+ }
+ }
+
+ public void setModelInfo(OnnxModelInfo modelInfo) {
+ Objects.requireNonNull(modelInfo, "Onnx model info cannot be null");
+ for (String onnxName : modelInfo.getInputs()) {
+ addInputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false);
+ }
+ for (String onnxName : modelInfo.getOutputs()) {
+ addOutputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false);
+ }
+ this.modelInfo = modelInfo;
+ }
+
+ public Map<String, String> getInputMap() { return Collections.unmodifiableMap(inputMap); }
+ public Map<String, String> getOutputMap() { return Collections.unmodifiableMap(outputMap); }
+
+ public String getDefaultOutput() {
+ return modelInfo != null ? modelInfo.getDefaultOutput() : "";
+ }
+
+ TensorType getTensorType(String onnxName, Map<String, TensorType> inputTypes) {
+ return modelInfo != null ? modelInfo.getTensorType(onnxName, inputTypes) : TensorType.empty;
+ }
+
+ public void setStatelessExecutionMode(String executionMode) {
+ if ("parallel".equalsIgnoreCase(executionMode)) {
+ this.statelessExecutionMode = "parallel";
+ } else if ("sequential".equalsIgnoreCase(executionMode)) {
+ this.statelessExecutionMode = "sequential";
+ }
+ }
+
+ public Optional<String> getStatelessExecutionMode() {
+ return Optional.ofNullable(statelessExecutionMode);
+ }
+
+ public void setStatelessInterOpThreads(int interOpThreads) {
+ if (interOpThreads >= 0) {
+ this.statelessInterOpThreads = interOpThreads;
+ }
+ }
+
+ public Optional<Integer> getStatelessInterOpThreads() {
+ return Optional.ofNullable(statelessInterOpThreads);
+ }
+
+ public void setStatelessIntraOpThreads(int intraOpThreads) {
+ if (intraOpThreads >= 0) {
+ this.statelessIntraOpThreads = intraOpThreads;
+ }
+ }
+
+ public Optional<Integer> getStatelessIntraOpThreads() {
+ return Optional.ofNullable(statelessIntraOpThreads);
+ }
+
+}