diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java | 38 |
1 files changed, 38 insertions, 0 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java index 3c42987512b..d85c0065d84 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java @@ -8,6 +8,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.Objects; +import java.util.Optional; /** * A global ONNX model distributed using file distribution, similar to ranking constants. @@ -20,6 +21,10 @@ public class OnnxModel extends DistributableResource { private Map<String, String> inputMap = new HashMap<>(); private Map<String, String> outputMap = new HashMap<>(); + private String statelessExecutionMode = null; + private Integer statelessInterOpThreads = null; + private Integer statelessIntraOpThreads = null; + public OnnxModel(String name) { super(name); } @@ -79,4 +84,37 @@ public class OnnxModel extends DistributableResource { TensorType getTensorType(String onnxName, Map<String, TensorType> inputTypes) { return modelInfo != null ? modelInfo.getTensorType(onnxName, inputTypes) : TensorType.empty; } + + public void setStatelessExecutionMode(String executionMode) { + if ("parallel".equalsIgnoreCase(executionMode)) { + this.statelessExecutionMode = "parallel"; + } else if ("sequential".equalsIgnoreCase(executionMode)) { + this.statelessExecutionMode = "sequential"; + } + } + + public Optional<String> getStatelessExecutionMode() { + return Optional.ofNullable(statelessExecutionMode); + } + + public void setStatelessInterOpThreads(int interOpThreads) { + if (interOpThreads >= 0) { + this.statelessInterOpThreads = interOpThreads; + } + } + + public Optional<Integer> getStatelessInterOpThreads() { + return Optional.ofNullable(statelessInterOpThreads); + } + + public void setStatelessIntraOpThreads(int intraOpThreads) { + if (intraOpThreads >= 0) { + this.statelessIntraOpThreads = intraOpThreads; + } + } + + public Optional<Integer> getStatelessIntraOpThreads() { + return Optional.ofNullable(statelessIntraOpThreads); + } + } |