diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition')
3 files changed, 50 insertions, 0 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java index 3c42987512b..d85c0065d84 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java @@ -8,6 +8,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.Objects; +import java.util.Optional; /** * A global ONNX model distributed using file distribution, similar to ranking constants. @@ -20,6 +21,10 @@ public class OnnxModel extends DistributableResource { private Map<String, String> inputMap = new HashMap<>(); private Map<String, String> outputMap = new HashMap<>(); + private String statelessExecutionMode = null; + private Integer statelessInterOpThreads = null; + private Integer statelessIntraOpThreads = null; + public OnnxModel(String name) { super(name); } @@ -79,4 +84,37 @@ public class OnnxModel extends DistributableResource { TensorType getTensorType(String onnxName, Map<String, TensorType> inputTypes) { return modelInfo != null ? modelInfo.getTensorType(onnxName, inputTypes) : TensorType.empty; } + + public void setStatelessExecutionMode(String executionMode) { + if ("parallel".equalsIgnoreCase(executionMode)) { + this.statelessExecutionMode = "parallel"; + } else if ("sequential".equalsIgnoreCase(executionMode)) { + this.statelessExecutionMode = "sequential"; + } + } + + public Optional<String> getStatelessExecutionMode() { + return Optional.ofNullable(statelessExecutionMode); + } + + public void setStatelessInterOpThreads(int interOpThreads) { + if (interOpThreads >= 0) { + this.statelessInterOpThreads = interOpThreads; + } + } + + public Optional<Integer> getStatelessInterOpThreads() { + return Optional.ofNullable(statelessInterOpThreads); + } + + public void setStatelessIntraOpThreads(int intraOpThreads) { + if (intraOpThreads >= 0) { + this.statelessIntraOpThreads = intraOpThreads; + } + } + + public Optional<Integer> getStatelessIntraOpThreads() { + return Optional.ofNullable(statelessIntraOpThreads); + } + } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java index cb61b8c9cec..dcbea70a32b 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java @@ -22,6 +22,7 @@ public class OnnxModels { public OnnxModels(FileRegistry fileRegistry) { this.fileRegistry = fileRegistry; } + public void add(OnnxModel model) { model.validate(); model.register(fileRegistry); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java index ad85f68cb8a..8291c69af2f 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java @@ -130,6 +130,10 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ } } + public OnnxModels getOnnxModels() { + return onnxModels; + } + public Map<String, RawRankProfile> getRankProfiles() { return rankProfiles; } @@ -182,6 +186,13 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ modelBuilder.fileref(model.getFileReference()); model.getInputMap().forEach((name, source) -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source))); model.getOutputMap().forEach((name, as) -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(name).as(as))); + if (model.getStatelessExecutionMode().isPresent()) + modelBuilder.stateless_execution_mode(model.getStatelessExecutionMode().get()); + if (model.getStatelessInterOpThreads().isPresent()) + modelBuilder.stateless_interop_threads(model.getStatelessInterOpThreads().get()); + if (model.getStatelessIntraOpThreads().isPresent()) + modelBuilder.stateless_intraop_threads(model.getStatelessIntraOpThreads().get()); + builder.model(modelBuilder); } } |