summaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java38
1 files changed, 38 insertions, 0 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
index 3c42987512b..d85c0065d84 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
@@ -8,6 +8,7 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
+import java.util.Optional;
/**
* A global ONNX model distributed using file distribution, similar to ranking constants.
@@ -20,6 +21,10 @@ public class OnnxModel extends DistributableResource {
private Map<String, String> inputMap = new HashMap<>();
private Map<String, String> outputMap = new HashMap<>();
+ private String statelessExecutionMode = null;
+ private Integer statelessInterOpThreads = null;
+ private Integer statelessIntraOpThreads = null;
+
public OnnxModel(String name) {
super(name);
}
@@ -79,4 +84,37 @@ public class OnnxModel extends DistributableResource {
TensorType getTensorType(String onnxName, Map<String, TensorType> inputTypes) {
return modelInfo != null ? modelInfo.getTensorType(onnxName, inputTypes) : TensorType.empty;
}
+
+ public void setStatelessExecutionMode(String executionMode) {
+ if ("parallel".equalsIgnoreCase(executionMode)) {
+ this.statelessExecutionMode = "parallel";
+ } else if ("sequential".equalsIgnoreCase(executionMode)) {
+ this.statelessExecutionMode = "sequential";
+ }
+ }
+
+ public Optional<String> getStatelessExecutionMode() {
+ return Optional.ofNullable(statelessExecutionMode);
+ }
+
+ public void setStatelessInterOpThreads(int interOpThreads) {
+ if (interOpThreads >= 0) {
+ this.statelessInterOpThreads = interOpThreads;
+ }
+ }
+
+ public Optional<Integer> getStatelessInterOpThreads() {
+ return Optional.ofNullable(statelessInterOpThreads);
+ }
+
+ public void setStatelessIntraOpThreads(int intraOpThreads) {
+ if (intraOpThreads >= 0) {
+ this.statelessIntraOpThreads = intraOpThreads;
+ }
+ }
+
+ public Optional<Integer> getStatelessIntraOpThreads() {
+ return Optional.ofNullable(statelessIntraOpThreads);
+ }
+
}