diff options
author | Lester Solbakken <lesters@oath.com> | 2021-09-30 16:05:03 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2021-09-30 16:05:03 +0200 |
commit | 387da217a9eb2a6f88b50f3608659a7d75c66aeb (patch) | |
tree | fbc82d6cb6add5d599425191cfc994c454fea4a6 /model-integration | |
parent | 2f5a11f868291b34a3aa2c28817b36c5d0ed3d52 (diff) |
Add parsing of ONNX Runtime session options to services.xml
Diffstat (limited to 'model-integration')
2 files changed, 66 insertions, 6 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java index b782a79f14b..4c44fca8c79 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java @@ -26,14 +26,16 @@ public class OnnxEvaluator { private final OrtSession session; public OnnxEvaluator(String modelPath) { + this(modelPath, null); + } + + public OnnxEvaluator(String modelPath, OnnxEvaluatorOptions options) { try { + if (options == null) { + options = new OnnxEvaluatorOptions(); + } environment = OrtEnvironment.getEnvironment(); - OrtSession.SessionOptions options = new OrtSession.SessionOptions(); - options.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT); - options.setIntraOpNumThreads(Math.max(1, Runtime.getRuntime().availableProcessors() / 4)); - options.setInterOpNumThreads(1); - options.setExecutionMode(OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL); - session = environment.createSession(modelPath, options); + session = environment.createSession(modelPath, options.getOptions()); } catch (OrtException e) { throw new RuntimeException("ONNX Runtime exception", e); } diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java new file mode 100644 index 00000000000..8467040e5c0 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java @@ -0,0 +1,58 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package ai.vespa.modelintegration.evaluator; + +import ai.onnxruntime.OrtEnvironment; +import ai.onnxruntime.OrtException; +import ai.onnxruntime.OrtSession; + +/** + * Session options for ONNX Runtime evaluation + * + * @author lesters + */ +public class OnnxEvaluatorOptions { + + private OrtSession.SessionOptions.OptLevel optimizationLevel; + private OrtSession.SessionOptions.ExecutionMode executionMode; + private int interOpThreads; + private int intraOpThreads; + + public OnnxEvaluatorOptions() { + // Defaults: + optimizationLevel = OrtSession.SessionOptions.OptLevel.ALL_OPT; + executionMode = OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL; + interOpThreads = 1; + intraOpThreads = Math.max(1, (int) Math.ceil(((double) Runtime.getRuntime().availableProcessors()) / 4)); + } + + public OrtSession.SessionOptions getOptions() throws OrtException { + OrtSession.SessionOptions options = new OrtSession.SessionOptions(); + options.setOptimizationLevel(optimizationLevel); + options.setExecutionMode(executionMode); + options.setInterOpNumThreads(interOpThreads); + options.setIntraOpNumThreads(intraOpThreads); + return options; + } + + public void setExecutionMode(String mode) { + if ("parallel".equalsIgnoreCase(mode)) { + executionMode = OrtSession.SessionOptions.ExecutionMode.PARALLEL; + } else if ("sequential".equalsIgnoreCase(mode)) { + executionMode = OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL; + } + } + + public void setInterOpThreads(int threads) { + if (threads >= 0) { + interOpThreads = threads; + } + } + + public void setIntraOpThreads(int threads) { + if (threads >= 0) { + intraOpThreads = threads; + } + } + +} |