aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/searchdefinition
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java38
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java1
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java11
3 files changed, 50 insertions, 0 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
index 3c42987512b..d85c0065d84 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
@@ -8,6 +8,7 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
+import java.util.Optional;
/**
* A global ONNX model distributed using file distribution, similar to ranking constants.
@@ -20,6 +21,10 @@ public class OnnxModel extends DistributableResource {
private Map<String, String> inputMap = new HashMap<>();
private Map<String, String> outputMap = new HashMap<>();
+ private String statelessExecutionMode = null;
+ private Integer statelessInterOpThreads = null;
+ private Integer statelessIntraOpThreads = null;
+
public OnnxModel(String name) {
super(name);
}
@@ -79,4 +84,37 @@ public class OnnxModel extends DistributableResource {
TensorType getTensorType(String onnxName, Map<String, TensorType> inputTypes) {
return modelInfo != null ? modelInfo.getTensorType(onnxName, inputTypes) : TensorType.empty;
}
+
+ public void setStatelessExecutionMode(String executionMode) {
+ if ("parallel".equalsIgnoreCase(executionMode)) {
+ this.statelessExecutionMode = "parallel";
+ } else if ("sequential".equalsIgnoreCase(executionMode)) {
+ this.statelessExecutionMode = "sequential";
+ }
+ }
+
+ public Optional<String> getStatelessExecutionMode() {
+ return Optional.ofNullable(statelessExecutionMode);
+ }
+
+ public void setStatelessInterOpThreads(int interOpThreads) {
+ if (interOpThreads >= 0) {
+ this.statelessInterOpThreads = interOpThreads;
+ }
+ }
+
+ public Optional<Integer> getStatelessInterOpThreads() {
+ return Optional.ofNullable(statelessInterOpThreads);
+ }
+
+ public void setStatelessIntraOpThreads(int intraOpThreads) {
+ if (intraOpThreads >= 0) {
+ this.statelessIntraOpThreads = intraOpThreads;
+ }
+ }
+
+ public Optional<Integer> getStatelessIntraOpThreads() {
+ return Optional.ofNullable(statelessIntraOpThreads);
+ }
+
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java
index cb61b8c9cec..dcbea70a32b 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java
@@ -22,6 +22,7 @@ public class OnnxModels {
public OnnxModels(FileRegistry fileRegistry) {
this.fileRegistry = fileRegistry;
}
+
public void add(OnnxModel model) {
model.validate();
model.register(fileRegistry);
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
index ad85f68cb8a..8291c69af2f 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
@@ -130,6 +130,10 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
}
}
+ public OnnxModels getOnnxModels() {
+ return onnxModels;
+ }
+
public Map<String, RawRankProfile> getRankProfiles() {
return rankProfiles;
}
@@ -182,6 +186,13 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
modelBuilder.fileref(model.getFileReference());
model.getInputMap().forEach((name, source) -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source)));
model.getOutputMap().forEach((name, as) -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(name).as(as)));
+ if (model.getStatelessExecutionMode().isPresent())
+ modelBuilder.stateless_execution_mode(model.getStatelessExecutionMode().get());
+ if (model.getStatelessInterOpThreads().isPresent())
+ modelBuilder.stateless_interop_threads(model.getStatelessInterOpThreads().get());
+ if (model.getStatelessIntraOpThreads().isPresent())
+ modelBuilder.stateless_intraop_threads(model.getStatelessIntraOpThreads().get());
+
builder.model(modelBuilder);
}
}