diff options
Diffstat (limited to 'config-model/src/main')
5 files changed, 83 insertions, 1 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java index 3c42987512b..d85c0065d84 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java @@ -8,6 +8,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.Objects; +import java.util.Optional; /** * A global ONNX model distributed using file distribution, similar to ranking constants. @@ -20,6 +21,10 @@ public class OnnxModel extends DistributableResource { private Map<String, String> inputMap = new HashMap<>(); private Map<String, String> outputMap = new HashMap<>(); + private String statelessExecutionMode = null; + private Integer statelessInterOpThreads = null; + private Integer statelessIntraOpThreads = null; + public OnnxModel(String name) { super(name); } @@ -79,4 +84,37 @@ public class OnnxModel extends DistributableResource { TensorType getTensorType(String onnxName, Map<String, TensorType> inputTypes) { return modelInfo != null ? modelInfo.getTensorType(onnxName, inputTypes) : TensorType.empty; } + + public void setStatelessExecutionMode(String executionMode) { + if ("parallel".equalsIgnoreCase(executionMode)) { + this.statelessExecutionMode = "parallel"; + } else if ("sequential".equalsIgnoreCase(executionMode)) { + this.statelessExecutionMode = "sequential"; + } + } + + public Optional<String> getStatelessExecutionMode() { + return Optional.ofNullable(statelessExecutionMode); + } + + public void setStatelessInterOpThreads(int interOpThreads) { + if (interOpThreads >= 0) { + this.statelessInterOpThreads = interOpThreads; + } + } + + public Optional<Integer> getStatelessInterOpThreads() { + return Optional.ofNullable(statelessInterOpThreads); + } + + public void setStatelessIntraOpThreads(int intraOpThreads) { + if (intraOpThreads >= 0) { + this.statelessIntraOpThreads = intraOpThreads; + } + } + + public Optional<Integer> getStatelessIntraOpThreads() { + return Optional.ofNullable(statelessIntraOpThreads); + } + } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java index cb61b8c9cec..dcbea70a32b 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java @@ -22,6 +22,7 @@ public class OnnxModels { public OnnxModels(FileRegistry fileRegistry) { this.fileRegistry = fileRegistry; } + public void add(OnnxModel model) { model.validate(); model.register(fileRegistry); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java index ad85f68cb8a..8291c69af2f 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java @@ -130,6 +130,10 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ } } + public OnnxModels getOnnxModels() { + return onnxModels; + } + public Map<String, RawRankProfile> getRankProfiles() { return rankProfiles; } @@ -182,6 +186,13 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ modelBuilder.fileref(model.getFileReference()); model.getInputMap().forEach((name, source) -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source))); model.getOutputMap().forEach((name, as) -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(name).as(as))); + if (model.getStatelessExecutionMode().isPresent()) + modelBuilder.stateless_execution_mode(model.getStatelessExecutionMode().get()); + if (model.getStatelessInterOpThreads().isPresent()) + modelBuilder.stateless_interop_threads(model.getStatelessInterOpThreads().get()); + if (model.getStatelessIntraOpThreads().isPresent()) + modelBuilder.stateless_intraop_threads(model.getStatelessIntraOpThreads().get()); + builder.model(modelBuilder); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java index 87d79728fae..2622a9e50b7 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java @@ -33,6 +33,7 @@ import com.yahoo.config.provision.Zone; import com.yahoo.container.logging.FileConnectionLog; import com.yahoo.osgi.provider.model.ComponentModel; import com.yahoo.search.rendering.RendererRegistry; +import com.yahoo.searchdefinition.OnnxModel; import com.yahoo.searchdefinition.derived.RankProfileList; import com.yahoo.security.X509CertificateUtils; import com.yahoo.text.XML; @@ -559,9 +560,31 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { RankProfileList profiles = context.vespaModel() != null ? context.vespaModel().rankProfileList() : RankProfileList.empty; + + Element onnxElement = XML.getChild(modelEvaluationElement, "onnx"); + Element modelsElement = XML.getChild(onnxElement, "models"); + for (Element modelElement : XML.getChildren(modelsElement, "model") ) { + OnnxModel onnxModel = profiles.getOnnxModels().get(modelElement.getAttribute("name")); + if (onnxModel == null) + continue; // Skip if model is not found + onnxModel.setStatelessExecutionMode(getStringValue(modelElement, "execution-mode", null)); + onnxModel.setStatelessInterOpThreads(getIntValue(modelElement, "interop-threads", -1)); + onnxModel.setStatelessIntraOpThreads(getIntValue(modelElement, "intraop-threads", -1)); + } + cluster.setModelEvaluation(new ContainerModelEvaluation(cluster, profiles)); } + private String getStringValue(Element element, String name, String defaultValue) { + Element child = XML.getChild(element, name); + return (child != null) ? child.getTextContent() : defaultValue; + } + + private int getIntValue(Element element, String name, int defaultValue) { + Element child = XML.getChild(element, name); + return (child != null) ? Integer.parseInt(child.getTextContent()) : defaultValue; + } + protected void addModelEvaluationBundles(ApplicationContainerCluster cluster) { /* These bundles are added to all application container clusters, even if they haven't * declared 'model-evaluation' in services.xml, because there are many public API packages diff --git a/config-model/src/main/resources/schema/containercluster.rnc b/config-model/src/main/resources/schema/containercluster.rnc index 992689a2189..945deec9f91 100644 --- a/config-model/src/main/resources/schema/containercluster.rnc +++ b/config-model/src/main/resources/schema/containercluster.rnc @@ -105,7 +105,16 @@ ZooKeeper = element zookeeper { } ModelEvaluation = element model-evaluation { - empty + element onnx { + element models { + element model { + attribute name { string } & + element intraop-threads { xsd:nonNegativeInteger }? & + element interop-threads { xsd:nonNegativeInteger }? & + element execution-mode { string "sequential" | string "parallel" }? + }* + }? + }? } Ssl = element ssl { |