aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java14
1 files changed, 8 insertions, 6 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
index b782a79f14b..4c44fca8c79 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
@@ -26,14 +26,16 @@ public class OnnxEvaluator {
private final OrtSession session;
public OnnxEvaluator(String modelPath) {
+ this(modelPath, null);
+ }
+
+ public OnnxEvaluator(String modelPath, OnnxEvaluatorOptions options) {
try {
+ if (options == null) {
+ options = new OnnxEvaluatorOptions();
+ }
environment = OrtEnvironment.getEnvironment();
- OrtSession.SessionOptions options = new OrtSession.SessionOptions();
- options.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT);
- options.setIntraOpNumThreads(Math.max(1, Runtime.getRuntime().availableProcessors() / 4));
- options.setInterOpNumThreads(1);
- options.setExecutionMode(OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL);
- session = environment.createSession(modelPath, options);
+ session = environment.createSession(modelPath, options.getOptions());
} catch (OrtException e) {
throw new RuntimeException("ONNX Runtime exception", e);
}