summaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2021-10-01 09:30:35 +0200
committerGitHub <noreply@github.com>2021-10-01 09:30:35 +0200
commitba6fb61c10b2ea9719e1f375d2efaf3182a8e257 (patch)
treeb911f0e3a3ba465209e101c2fce7bf70a481e86e /config-model
parentb8341539baf9c5ccc933e161afbe9facd7eb87ca (diff)
parent387da217a9eb2a6f88b50f3608659a7d75c66aeb (diff)
Merge pull request #19380 from vespa-engine/lesters/stateless-eval-options
Add parsing of ONNX Runtime session options to services.xml
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java38
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java1
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java11
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java23
-rw-r--r--config-model/src/main/resources/schema/containercluster.rnc11
-rw-r--r--config-model/src/test/cfg/application/onnx/services.xml14
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java4
7 files changed, 100 insertions, 2 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
index 3c42987512b..d85c0065d84 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
@@ -8,6 +8,7 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
+import java.util.Optional;
/**
* A global ONNX model distributed using file distribution, similar to ranking constants.
@@ -20,6 +21,10 @@ public class OnnxModel extends DistributableResource {
private Map<String, String> inputMap = new HashMap<>();
private Map<String, String> outputMap = new HashMap<>();
+ private String statelessExecutionMode = null;
+ private Integer statelessInterOpThreads = null;
+ private Integer statelessIntraOpThreads = null;
+
public OnnxModel(String name) {
super(name);
}
@@ -79,4 +84,37 @@ public class OnnxModel extends DistributableResource {
TensorType getTensorType(String onnxName, Map<String, TensorType> inputTypes) {
return modelInfo != null ? modelInfo.getTensorType(onnxName, inputTypes) : TensorType.empty;
}
+
+ public void setStatelessExecutionMode(String executionMode) {
+ if ("parallel".equalsIgnoreCase(executionMode)) {
+ this.statelessExecutionMode = "parallel";
+ } else if ("sequential".equalsIgnoreCase(executionMode)) {
+ this.statelessExecutionMode = "sequential";
+ }
+ }
+
+ public Optional<String> getStatelessExecutionMode() {
+ return Optional.ofNullable(statelessExecutionMode);
+ }
+
+ public void setStatelessInterOpThreads(int interOpThreads) {
+ if (interOpThreads >= 0) {
+ this.statelessInterOpThreads = interOpThreads;
+ }
+ }
+
+ public Optional<Integer> getStatelessInterOpThreads() {
+ return Optional.ofNullable(statelessInterOpThreads);
+ }
+
+ public void setStatelessIntraOpThreads(int intraOpThreads) {
+ if (intraOpThreads >= 0) {
+ this.statelessIntraOpThreads = intraOpThreads;
+ }
+ }
+
+ public Optional<Integer> getStatelessIntraOpThreads() {
+ return Optional.ofNullable(statelessIntraOpThreads);
+ }
+
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java
index cb61b8c9cec..dcbea70a32b 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java
@@ -22,6 +22,7 @@ public class OnnxModels {
public OnnxModels(FileRegistry fileRegistry) {
this.fileRegistry = fileRegistry;
}
+
public void add(OnnxModel model) {
model.validate();
model.register(fileRegistry);
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
index ad85f68cb8a..8291c69af2f 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
@@ -130,6 +130,10 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
}
}
+ public OnnxModels getOnnxModels() {
+ return onnxModels;
+ }
+
public Map<String, RawRankProfile> getRankProfiles() {
return rankProfiles;
}
@@ -182,6 +186,13 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
modelBuilder.fileref(model.getFileReference());
model.getInputMap().forEach((name, source) -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source)));
model.getOutputMap().forEach((name, as) -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(name).as(as)));
+ if (model.getStatelessExecutionMode().isPresent())
+ modelBuilder.stateless_execution_mode(model.getStatelessExecutionMode().get());
+ if (model.getStatelessInterOpThreads().isPresent())
+ modelBuilder.stateless_interop_threads(model.getStatelessInterOpThreads().get());
+ if (model.getStatelessIntraOpThreads().isPresent())
+ modelBuilder.stateless_intraop_threads(model.getStatelessIntraOpThreads().get());
+
builder.model(modelBuilder);
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
index 87d79728fae..2622a9e50b7 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
@@ -33,6 +33,7 @@ import com.yahoo.config.provision.Zone;
import com.yahoo.container.logging.FileConnectionLog;
import com.yahoo.osgi.provider.model.ComponentModel;
import com.yahoo.search.rendering.RendererRegistry;
+import com.yahoo.searchdefinition.OnnxModel;
import com.yahoo.searchdefinition.derived.RankProfileList;
import com.yahoo.security.X509CertificateUtils;
import com.yahoo.text.XML;
@@ -559,9 +560,31 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> {
RankProfileList profiles =
context.vespaModel() != null ? context.vespaModel().rankProfileList() : RankProfileList.empty;
+
+ Element onnxElement = XML.getChild(modelEvaluationElement, "onnx");
+ Element modelsElement = XML.getChild(onnxElement, "models");
+ for (Element modelElement : XML.getChildren(modelsElement, "model") ) {
+ OnnxModel onnxModel = profiles.getOnnxModels().get(modelElement.getAttribute("name"));
+ if (onnxModel == null)
+ continue; // Skip if model is not found
+ onnxModel.setStatelessExecutionMode(getStringValue(modelElement, "execution-mode", null));
+ onnxModel.setStatelessInterOpThreads(getIntValue(modelElement, "interop-threads", -1));
+ onnxModel.setStatelessIntraOpThreads(getIntValue(modelElement, "intraop-threads", -1));
+ }
+
cluster.setModelEvaluation(new ContainerModelEvaluation(cluster, profiles));
}
+ private String getStringValue(Element element, String name, String defaultValue) {
+ Element child = XML.getChild(element, name);
+ return (child != null) ? child.getTextContent() : defaultValue;
+ }
+
+ private int getIntValue(Element element, String name, int defaultValue) {
+ Element child = XML.getChild(element, name);
+ return (child != null) ? Integer.parseInt(child.getTextContent()) : defaultValue;
+ }
+
protected void addModelEvaluationBundles(ApplicationContainerCluster cluster) {
/* These bundles are added to all application container clusters, even if they haven't
* declared 'model-evaluation' in services.xml, because there are many public API packages
diff --git a/config-model/src/main/resources/schema/containercluster.rnc b/config-model/src/main/resources/schema/containercluster.rnc
index 992689a2189..945deec9f91 100644
--- a/config-model/src/main/resources/schema/containercluster.rnc
+++ b/config-model/src/main/resources/schema/containercluster.rnc
@@ -105,7 +105,16 @@ ZooKeeper = element zookeeper {
}
ModelEvaluation = element model-evaluation {
- empty
+ element onnx {
+ element models {
+ element model {
+ attribute name { string } &
+ element intraop-threads { xsd:nonNegativeInteger }? &
+ element interop-threads { xsd:nonNegativeInteger }? &
+ element execution-mode { string "sequential" | string "parallel" }?
+ }*
+ }?
+ }?
}
Ssl = element ssl {
diff --git a/config-model/src/test/cfg/application/onnx/services.xml b/config-model/src/test/cfg/application/onnx/services.xml
index 8731558c6f7..1cae9250009 100644
--- a/config-model/src/test/cfg/application/onnx/services.xml
+++ b/config-model/src/test/cfg/application/onnx/services.xml
@@ -3,7 +3,19 @@
<services version="1.0">
<container version="1.0">
- <model-evaluation/>
+ <model-evaluation>
+ <onnx>
+ <models>
+ <model name="mul">
+ <intraop-threads>2</intraop-threads>
+ </model>
+ <model name="non-existent-model">
+ <interop-threads>400</interop-threads>
+ <execution-mode>parallel</execution-mode>
+ </model>
+ </models>
+ </onnx>
+ </model-evaluation>
<nodes>
<node hostalias="node1" />
</nodes>
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java
index b369560be74..27f1fc26b15 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java
@@ -123,6 +123,10 @@ public class StatelessOnnxEvaluationTest {
Tensor output = evaluator.bind("input1", input1).bind("input2", input2).evaluate();
assertEquals(6.0, output.sum().asDouble(), 1e-9);
+ OnnxModelsConfig.Model mulModel = onnxModelsConfig.model().get(0);
+ assertEquals(2, mulModel.stateless_intraop_threads());
+ assertEquals(-1, mulModel.stateless_interop_threads());
+ assertEquals("", mulModel.stateless_execution_mode());
}
}