aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-09-30 16:05:03 +0200
committerLester Solbakken <lesters@oath.com>2021-09-30 16:05:03 +0200
commit387da217a9eb2a6f88b50f3608659a7d75c66aeb (patch)
treefbc82d6cb6add5d599425191cfc994c454fea4a6 /model-integration
parent2f5a11f868291b34a3aa2c28817b36c5d0ed3d52 (diff)
Add parsing of ONNX Runtime session options to services.xml
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java14
-rw-r--r--model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java58
2 files changed, 66 insertions, 6 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
index b782a79f14b..4c44fca8c79 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
@@ -26,14 +26,16 @@ public class OnnxEvaluator {
private final OrtSession session;
public OnnxEvaluator(String modelPath) {
+ this(modelPath, null);
+ }
+
+ public OnnxEvaluator(String modelPath, OnnxEvaluatorOptions options) {
try {
+ if (options == null) {
+ options = new OnnxEvaluatorOptions();
+ }
environment = OrtEnvironment.getEnvironment();
- OrtSession.SessionOptions options = new OrtSession.SessionOptions();
- options.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT);
- options.setIntraOpNumThreads(Math.max(1, Runtime.getRuntime().availableProcessors() / 4));
- options.setInterOpNumThreads(1);
- options.setExecutionMode(OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL);
- session = environment.createSession(modelPath, options);
+ session = environment.createSession(modelPath, options.getOptions());
} catch (OrtException e) {
throw new RuntimeException("ONNX Runtime exception", e);
}
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
new file mode 100644
index 00000000000..8467040e5c0
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
@@ -0,0 +1,58 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package ai.vespa.modelintegration.evaluator;
+
+import ai.onnxruntime.OrtEnvironment;
+import ai.onnxruntime.OrtException;
+import ai.onnxruntime.OrtSession;
+
+/**
+ * Session options for ONNX Runtime evaluation
+ *
+ * @author lesters
+ */
+public class OnnxEvaluatorOptions {
+
+ private OrtSession.SessionOptions.OptLevel optimizationLevel;
+ private OrtSession.SessionOptions.ExecutionMode executionMode;
+ private int interOpThreads;
+ private int intraOpThreads;
+
+ public OnnxEvaluatorOptions() {
+ // Defaults:
+ optimizationLevel = OrtSession.SessionOptions.OptLevel.ALL_OPT;
+ executionMode = OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL;
+ interOpThreads = 1;
+ intraOpThreads = Math.max(1, (int) Math.ceil(((double) Runtime.getRuntime().availableProcessors()) / 4));
+ }
+
+ public OrtSession.SessionOptions getOptions() throws OrtException {
+ OrtSession.SessionOptions options = new OrtSession.SessionOptions();
+ options.setOptimizationLevel(optimizationLevel);
+ options.setExecutionMode(executionMode);
+ options.setInterOpNumThreads(interOpThreads);
+ options.setIntraOpNumThreads(intraOpThreads);
+ return options;
+ }
+
+ public void setExecutionMode(String mode) {
+ if ("parallel".equalsIgnoreCase(mode)) {
+ executionMode = OrtSession.SessionOptions.ExecutionMode.PARALLEL;
+ } else if ("sequential".equalsIgnoreCase(mode)) {
+ executionMode = OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL;
+ }
+ }
+
+ public void setInterOpThreads(int threads) {
+ if (threads >= 0) {
+ interOpThreads = threads;
+ }
+ }
+
+ public void setIntraOpThreads(int threads) {
+ if (threads >= 0) {
+ intraOpThreads = threads;
+ }
+ }
+
+}