aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
author Lester Solbakken <lesters@oath.com> 2021-09-30 16:05:03 +0200
committer Lester Solbakken <lesters@oath.com> 2021-09-30 16:05:03 +0200
commit 387da217a9eb2a6f88b50f3608659a7d75c66aeb (patch)
tree fbc82d6cb6add5d599425191cfc994c454fea4a6
parent 2f5a11f868291b34a3aa2c28817b36c5d0ed3d52 (diff)
Add parsing of ONNX Runtime session options to services.xml
-rw-r--r-- config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java 38
-rw-r--r-- config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java 1
-rw-r--r-- config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java 11
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java 23
-rw-r--r-- config-model/src/main/resources/schema/containercluster.rnc 11
-rw-r--r-- config-model/src/test/cfg/application/onnx/services.xml 14
-rw-r--r-- config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java 4
-rw-r--r-- model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java 7
-rw-r--r-- model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java 9
-rw-r--r-- model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java 14
-rw-r--r-- model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java 58
-rw-r--r-- searchcore/src/vespa/searchcore/config/onnx-models.def 17
12 files changed, 189 insertions, 18 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
index 3c42987512b..d85c0065d84 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
@@ -8,6 +8,7 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
+import java.util.Optional;
/**
* A global ONNX model distributed using file distribution, similar to ranking constants.
@@ -20,6 +21,10 @@ public class OnnxModel extends DistributableResource {
private Map<String, String> inputMap = new HashMap<>();
private Map<String, String> outputMap = new HashMap<>();
+ private String statelessExecutionMode = null;
+ private Integer statelessInterOpThreads = null;
+ private Integer statelessIntraOpThreads = null;
+
public OnnxModel(String name) {
super(name);
}
@@ -79,4 +84,37 @@ public class OnnxModel extends DistributableResource {
TensorType getTensorType(String onnxName, Map<String, TensorType> inputTypes) {
return modelInfo != null ? modelInfo.getTensorType(onnxName, inputTypes) : TensorType.empty;
}
+
+ /**
+  * Sets the execution mode to use when this model is evaluated statelessly.
+  * Only "parallel" and "sequential" are recognized, ignoring case; any other
+  * value (including null) leaves the current setting untouched.
+  */
+ public void setStatelessExecutionMode(String executionMode) {
+     for (String supportedMode : new String[] { "parallel", "sequential" }) {
+         if (supportedMode.equalsIgnoreCase(executionMode)) {
+             // Store the canonical lower-case spelling, not the caller's casing
+             this.statelessExecutionMode = supportedMode;
+             return;
+         }
+     }
+ }
+
+ /** Returns the stateless execution mode, or empty if no valid mode has been set. */
+ public Optional<String> getStatelessExecutionMode() {
+     if (statelessExecutionMode == null) {
+         return Optional.empty();
+     }
+     return Optional.of(statelessExecutionMode);
+ }
+
+ /**
+  * Sets the number of inter-op threads to use for stateless evaluation.
+  * Negative values are ignored and leave the current setting untouched.
+  */
+ public void setStatelessInterOpThreads(int interOpThreads) {
+     if (interOpThreads < 0) {
+         return;
+     }
+     this.statelessInterOpThreads = interOpThreads;
+ }
+
+ /** Returns the stateless inter-op thread count, or empty if none has been set. */
+ public Optional<Integer> getStatelessInterOpThreads() {
+     if (statelessInterOpThreads == null) {
+         return Optional.empty();
+     }
+     return Optional.of(statelessInterOpThreads);
+ }
+
+ /**
+  * Sets the number of intra-op threads to use for stateless evaluation.
+  * Negative values are ignored and leave the current setting untouched.
+  */
+ public void setStatelessIntraOpThreads(int intraOpThreads) {
+     if (intraOpThreads < 0) {
+         return;
+     }
+     this.statelessIntraOpThreads = intraOpThreads;
+ }
+
+ /** Returns the stateless intra-op thread count, or empty if none has been set. */
+ public Optional<Integer> getStatelessIntraOpThreads() {
+     if (statelessIntraOpThreads == null) {
+         return Optional.empty();
+     }
+     return Optional.of(statelessIntraOpThreads);
+ }
+
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java
index cb61b8c9cec..dcbea70a32b 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java
@@ -22,6 +22,7 @@ public class OnnxModels {
public OnnxModels(FileRegistry fileRegistry) {
this.fileRegistry = fileRegistry;
}
+
public void add(OnnxModel model) {
model.validate();
model.register(fileRegistry);
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
index ad85f68cb8a..8291c69af2f 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
@@ -130,6 +130,10 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
}
}
+ /** Returns the ONNX models held by this rank profile list. */
+ public OnnxModels getOnnxModels() {
+     return this.onnxModels;
+ }
+
public Map<String, RawRankProfile> getRankProfiles() {
return rankProfiles;
}
@@ -182,6 +186,13 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
modelBuilder.fileref(model.getFileReference());
model.getInputMap().forEach((name, source) -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source)));
model.getOutputMap().forEach((name, as) -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(name).as(as)));
+ if (model.getStatelessExecutionMode().isPresent())
+ modelBuilder.stateless_execution_mode(model.getStatelessExecutionMode().get());
+ if (model.getStatelessInterOpThreads().isPresent())
+ modelBuilder.stateless_interop_threads(model.getStatelessInterOpThreads().get());
+ if (model.getStatelessIntraOpThreads().isPresent())
+ modelBuilder.stateless_intraop_threads(model.getStatelessIntraOpThreads().get());
+
builder.model(modelBuilder);
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
index c318180fd56..c62dee68b2d 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
@@ -33,6 +33,7 @@ import com.yahoo.config.provision.Zone;
import com.yahoo.container.logging.FileConnectionLog;
import com.yahoo.osgi.provider.model.ComponentModel;
import com.yahoo.search.rendering.RendererRegistry;
+import com.yahoo.searchdefinition.OnnxModel;
import com.yahoo.searchdefinition.derived.RankProfileList;
import com.yahoo.security.X509CertificateUtils;
import com.yahoo.text.XML;
@@ -557,9 +558,31 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> {
RankProfileList profiles =
context.vespaModel() != null ? context.vespaModel().rankProfileList() : RankProfileList.empty;
+
+ Element onnxElement = XML.getChild(modelEvaluationElement, "onnx");
+ Element modelsElement = XML.getChild(onnxElement, "models");
+ for (Element modelElement : XML.getChildren(modelsElement, "model") ) {
+ OnnxModel onnxModel = profiles.getOnnxModels().get(modelElement.getAttribute("name"));
+ if (onnxModel == null)
+ continue; // Skip if model is not found
+ onnxModel.setStatelessExecutionMode(getStringValue(modelElement, "execution-mode", null));
+ onnxModel.setStatelessInterOpThreads(getIntValue(modelElement, "interop-threads", -1));
+ onnxModel.setStatelessIntraOpThreads(getIntValue(modelElement, "intraop-threads", -1));
+ }
+
cluster.setModelEvaluation(new ContainerModelEvaluation(cluster, profiles));
}
+ /**
+  * Returns the text content of the child element with the given name,
+  * or the given default value if the element has no such child.
+  */
+ private String getStringValue(Element element, String name, String defaultValue) {
+     Element child = XML.getChild(element, name);
+     if (child == null) {
+         return defaultValue;
+     }
+     return child.getTextContent();
+ }
+
+ /**
+  * Returns the text content of the child element with the given name parsed as an int,
+  * or the given default value if the element has no such child.
+  *
+  * Surrounding whitespace in the element text is ignored, so that
+  * e.g. <code>&lt;interop-threads&gt; 2 &lt;/interop-threads&gt;</code> is accepted.
+  *
+  * @throws IllegalArgumentException if the child exists but its content is not a valid int
+  */
+ private int getIntValue(Element element, String name, int defaultValue) {
+     Element child = XML.getChild(element, name);
+     if (child == null) {
+         return defaultValue;
+     }
+     String text = child.getTextContent().trim();
+     try {
+         return Integer.parseInt(text);
+     } catch (NumberFormatException e) {
+         // Name the offending element so the error is actionable for the application owner
+         throw new IllegalArgumentException("Expected an integer value in element '" + name +
+                                            "', but got '" + text + "'", e);
+     }
+ }
+
protected void addModelEvaluationBundles(ApplicationContainerCluster cluster) {
/* These bundles are added to all application container clusters, even if they haven't
* declared 'model-evaluation' in services.xml, because there are many public API packages
diff --git a/config-model/src/main/resources/schema/containercluster.rnc b/config-model/src/main/resources/schema/containercluster.rnc
index 992689a2189..945deec9f91 100644
--- a/config-model/src/main/resources/schema/containercluster.rnc
+++ b/config-model/src/main/resources/schema/containercluster.rnc
@@ -105,7 +105,16 @@ ZooKeeper = element zookeeper {
}
ModelEvaluation = element model-evaluation {
- empty
+ element onnx {
+ element models {
+ element model {
+ attribute name { string } &
+ element intraop-threads { xsd:nonNegativeInteger }? &
+ element interop-threads { xsd:nonNegativeInteger }? &
+ element execution-mode { string "sequential" | string "parallel" }?
+ }*
+ }?
+ }?
}
Ssl = element ssl {
diff --git a/config-model/src/test/cfg/application/onnx/services.xml b/config-model/src/test/cfg/application/onnx/services.xml
index 8731558c6f7..1cae9250009 100644
--- a/config-model/src/test/cfg/application/onnx/services.xml
+++ b/config-model/src/test/cfg/application/onnx/services.xml
@@ -3,7 +3,19 @@
<services version="1.0">
<container version="1.0">
- <model-evaluation/>
+ <model-evaluation>
+ <onnx>
+ <models>
+ <model name="mul">
+ <intraop-threads>2</intraop-threads>
+ </model>
+ <model name="non-existent-model">
+ <interop-threads>400</interop-threads>
+ <execution-mode>parallel</execution-mode>
+ </model>
+ </models>
+ </onnx>
+ </model-evaluation>
<nodes>
<node hostalias="node1" />
</nodes>
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java
index b369560be74..27f1fc26b15 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java
@@ -123,6 +123,10 @@ public class StatelessOnnxEvaluationTest {
Tensor output = evaluator.bind("input1", input1).bind("input2", input2).evaluate();
assertEquals(6.0, output.sum().asDouble(), 1e-9);
+ OnnxModelsConfig.Model mulModel = onnxModelsConfig.model().get(0);
+ assertEquals(2, mulModel.stateless_intraop_threads());
+ assertEquals(-1, mulModel.stateless_interop_threads());
+ assertEquals("", mulModel.stateless_execution_mode());
}
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
index dc27c43ef70..b014f60095e 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
@@ -2,6 +2,7 @@
package ai.vespa.models.evaluation;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
+import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -17,12 +18,14 @@ class OnnxModel {
private final String name;
private final File modelFile;
+ private final OnnxEvaluatorOptions options;
private OnnxEvaluator evaluator;
- OnnxModel(String name, File modelFile) {
+ OnnxModel(String name, File modelFile, OnnxEvaluatorOptions options) {
this.name = name;
this.modelFile = modelFile;
+ this.options = options;
}
public String name() {
@@ -31,7 +34,7 @@ class OnnxModel {
public void load() {
if (evaluator == null) {
- evaluator = new OnnxEvaluator(modelFile.getPath());
+ evaluator = new OnnxEvaluator(modelFile.getPath(), options);
}
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
index fbfd34814ac..335c39e02a1 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.evaluation;
+import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import com.yahoo.collections.Pair;
import com.yahoo.config.FileReference;
import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
@@ -182,7 +183,13 @@ public class RankProfilesConfigImporter {
try {
String name = onnxModelConfig.name();
File file = fileAcquirer.waitFor(onnxModelConfig.fileref(), 7, TimeUnit.DAYS);
- return new OnnxModel(name, file);
+
+ OnnxEvaluatorOptions options = new OnnxEvaluatorOptions();
+ options.setExecutionMode(onnxModelConfig.stateless_execution_mode());
+ options.setInterOpThreads(onnxModelConfig.stateless_interop_threads());
+ options.setIntraOpThreads(onnxModelConfig.stateless_intraop_threads());
+
+ return new OnnxModel(name, file, options);
} catch (InterruptedException e) {
throw new IllegalStateException("Gave up waiting for ONNX model " + onnxModelConfig.name());
}
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
index b782a79f14b..4c44fca8c79 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
@@ -26,14 +26,16 @@ public class OnnxEvaluator {
private final OrtSession session;
public OnnxEvaluator(String modelPath) {
+ this(modelPath, null);
+ }
+
+ public OnnxEvaluator(String modelPath, OnnxEvaluatorOptions options) {
try {
+ if (options == null) {
+ options = new OnnxEvaluatorOptions();
+ }
environment = OrtEnvironment.getEnvironment();
- OrtSession.SessionOptions options = new OrtSession.SessionOptions();
- options.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT);
- options.setIntraOpNumThreads(Math.max(1, Runtime.getRuntime().availableProcessors() / 4));
- options.setInterOpNumThreads(1);
- options.setExecutionMode(OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL);
- session = environment.createSession(modelPath, options);
+ session = environment.createSession(modelPath, options.getOptions());
} catch (OrtException e) {
throw new RuntimeException("ONNX Runtime exception", e);
}
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
new file mode 100644
index 00000000000..8467040e5c0
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
@@ -0,0 +1,58 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package ai.vespa.modelintegration.evaluator;
+
+import ai.onnxruntime.OrtEnvironment;
+import ai.onnxruntime.OrtException;
+import ai.onnxruntime.OrtSession;
+
+/**
+ * Session options for ONNX Runtime evaluation.
+ *
+ * Starts out with defaults (full optimization, sequential execution, one inter-op thread,
+ * and a quarter of the available processors - rounded up, at least one - as intra-op threads).
+ * The setters only override a default when given a recognized value, so "not set" markers
+ * such as an empty execution mode or a negative thread count can be passed straight through.
+ *
+ * @author lesters
+ */
+public class OnnxEvaluatorOptions {
+
+    private final OrtSession.SessionOptions.OptLevel optimizationLevel;
+    private OrtSession.SessionOptions.ExecutionMode executionMode;
+    private int interOpThreads;
+    private int intraOpThreads;
+
+    public OnnxEvaluatorOptions() {
+        // Defaults:
+        optimizationLevel = OrtSession.SessionOptions.OptLevel.ALL_OPT;
+        executionMode = OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL;
+        interOpThreads = 1;
+        intraOpThreads = Math.max(1, (int) Math.ceil(((double) Runtime.getRuntime().availableProcessors()) / 4));
+    }
+
+    /**
+     * Creates a new ONNX Runtime session options object reflecting the current settings.
+     * The caller owns the returned object.
+     *
+     * @throws OrtException if ONNX Runtime rejects any of the settings
+     */
+    public OrtSession.SessionOptions getOptions() throws OrtException {
+        OrtSession.SessionOptions options = new OrtSession.SessionOptions();
+        try {
+            options.setOptimizationLevel(optimizationLevel);
+            options.setExecutionMode(executionMode);
+            options.setInterOpNumThreads(interOpThreads);
+            options.setIntraOpNumThreads(intraOpThreads);
+            return options;
+        } catch (OrtException | RuntimeException e) {
+            // Don't leak the half-configured options object when a setting is rejected
+            options.close();
+            throw e;
+        }
+    }
+
+    /**
+     * Sets the execution mode from its textual name, ignoring case.
+     * Only "parallel" and "sequential" are recognized; any other value
+     * (including null and the empty string) keeps the current mode.
+     */
+    public void setExecutionMode(String mode) {
+        if ("parallel".equalsIgnoreCase(mode)) {
+            executionMode = OrtSession.SessionOptions.ExecutionMode.PARALLEL;
+        } else if ("sequential".equalsIgnoreCase(mode)) {
+            executionMode = OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL;
+        }
+    }
+
+    /** Sets the number of inter-op threads. Negative values are ignored and keep the current value. */
+    public void setInterOpThreads(int threads) {
+        if (threads >= 0) {
+            interOpThreads = threads;
+        }
+    }
+
+    /** Sets the number of intra-op threads. Negative values are ignored and keep the current value. */
+    public void setIntraOpThreads(int threads) {
+        if (threads >= 0) {
+            intraOpThreads = threads;
+        }
+    }
+
+}
diff --git a/searchcore/src/vespa/searchcore/config/onnx-models.def b/searchcore/src/vespa/searchcore/config/onnx-models.def
index 33ea90002c8..2c87a5a78ad 100644
--- a/searchcore/src/vespa/searchcore/config/onnx-models.def
+++ b/searchcore/src/vespa/searchcore/config/onnx-models.def
@@ -1,10 +1,13 @@
# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
namespace=vespa.config.search.core
-model[].name string
-model[].fileref file
-model[].input[].name string
-model[].input[].source string
-model[].output[].name string
-model[].output[].as string
-model[].dry_run_on_setup bool default=false
+model[].name string
+model[].fileref file
+model[].input[].name string
+model[].input[].source string
+model[].output[].name string
+model[].output[].as string
+model[].dry_run_on_setup bool default=false
+model[].stateless_execution_mode string default=""
+model[].stateless_interop_threads int default=-1
+model[].stateless_intraop_threads int default=-1