diff options
author | Lester Solbakken <lesters@oath.com> | 2021-09-30 16:05:03 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2021-09-30 16:05:03 +0200 |
commit | 387da217a9eb2a6f88b50f3608659a7d75c66aeb (patch) | |
tree | fbc82d6cb6add5d599425191cfc994c454fea4a6 /model-evaluation/src | |
parent | 2f5a11f868291b34a3aa2c28817b36c5d0ed3d52 (diff) |
Add parsing of ONNX Runtime session options to services.xml
Diffstat (limited to 'model-evaluation/src')
-rw-r--r-- | model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java | 7 | ||||
-rw-r--r-- | model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java | 9 |
2 files changed, 13 insertions, 3 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java index dc27c43ef70..b014f60095e 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java @@ -2,6 +2,7 @@ package ai.vespa.models.evaluation; import ai.vespa.modelintegration.evaluator.OnnxEvaluator; +import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -17,12 +18,14 @@ class OnnxModel { private final String name; private final File modelFile; + private final OnnxEvaluatorOptions options; private OnnxEvaluator evaluator; - OnnxModel(String name, File modelFile) { + OnnxModel(String name, File modelFile, OnnxEvaluatorOptions options) { this.name = name; this.modelFile = modelFile; + this.options = options; } public String name() { @@ -31,7 +34,7 @@ class OnnxModel { public void load() { if (evaluator == null) { - evaluator = new OnnxEvaluator(modelFile.getPath()); + evaluator = new OnnxEvaluator(modelFile.getPath(), options); } } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java index fbfd34814ac..335c39e02a1 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.evaluation; +import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions; import com.yahoo.collections.Pair; import com.yahoo.config.FileReference; import com.yahoo.filedistribution.fileacquirer.FileAcquirer; @@ -182,7 +183,13 @@ public class RankProfilesConfigImporter { try { String name = onnxModelConfig.name(); File file = fileAcquirer.waitFor(onnxModelConfig.fileref(), 7, TimeUnit.DAYS); - return new OnnxModel(name, file); + + OnnxEvaluatorOptions options = new OnnxEvaluatorOptions(); + options.setExecutionMode(onnxModelConfig.stateless_execution_mode()); + options.setInterOpThreads(onnxModelConfig.stateless_interop_threads()); + options.setIntraOpThreads(onnxModelConfig.stateless_intraop_threads()); + + return new OnnxModel(name, file, options); } catch (InterruptedException e) { throw new IllegalStateException("Gave up waiting for ONNX model " + onnxModelConfig.name()); } |