about | summary | refs | log | tree | commit | diff | stats
path: root/model-evaluation
diff options
context:
space:
mode:
author: Lester Solbakken <lesters@oath.com> 2021-09-30 16:05:03 +0200
committer: Lester Solbakken <lesters@oath.com> 2021-09-30 16:05:03 +0200
commit: 387da217a9eb2a6f88b50f3608659a7d75c66aeb (patch)
tree: fbc82d6cb6add5d599425191cfc994c454fea4a6 /model-evaluation
parent: 2f5a11f868291b34a3aa2c28817b36c5d0ed3d52 (diff)
Add parsing of ONNX Runtime session options to services.xml
Diffstat (limited to 'model-evaluation')
-rw-r--r-- model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java | 7
-rw-r--r-- model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java | 9
2 files changed, 13 insertions, 3 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
index dc27c43ef70..b014f60095e 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
@@ -2,6 +2,7 @@
package ai.vespa.models.evaluation;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
+import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -17,12 +18,14 @@ class OnnxModel {
private final String name;
private final File modelFile;
+ private final OnnxEvaluatorOptions options;
private OnnxEvaluator evaluator;
- OnnxModel(String name, File modelFile) {
+ OnnxModel(String name, File modelFile, OnnxEvaluatorOptions options) {
this.name = name;
this.modelFile = modelFile;
+ this.options = options;
}
public String name() {
@@ -31,7 +34,7 @@ class OnnxModel {
public void load() {
if (evaluator == null) {
- evaluator = new OnnxEvaluator(modelFile.getPath());
+ evaluator = new OnnxEvaluator(modelFile.getPath(), options);
}
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
index fbfd34814ac..335c39e02a1 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.evaluation;
+import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import com.yahoo.collections.Pair;
import com.yahoo.config.FileReference;
import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
@@ -182,7 +183,13 @@ public class RankProfilesConfigImporter {
try {
String name = onnxModelConfig.name();
File file = fileAcquirer.waitFor(onnxModelConfig.fileref(), 7, TimeUnit.DAYS);
- return new OnnxModel(name, file);
+
+ OnnxEvaluatorOptions options = new OnnxEvaluatorOptions();
+ options.setExecutionMode(onnxModelConfig.stateless_execution_mode());
+ options.setInterOpThreads(onnxModelConfig.stateless_interop_threads());
+ options.setIntraOpThreads(onnxModelConfig.stateless_intraop_threads());
+
+ return new OnnxModel(name, file, options);
} catch (InterruptedException e) {
throw new IllegalStateException("Gave up waiting for ONNX model " + onnxModelConfig.name());
}