diff options
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java | 14 |
1 files changed, 8 insertions, 6 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java index b782a79f14b..4c44fca8c79 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java @@ -26,14 +26,16 @@ public class OnnxEvaluator { private final OrtSession session; public OnnxEvaluator(String modelPath) { + this(modelPath, null); + } + + public OnnxEvaluator(String modelPath, OnnxEvaluatorOptions options) { try { + if (options == null) { + options = new OnnxEvaluatorOptions(); + } environment = OrtEnvironment.getEnvironment(); - OrtSession.SessionOptions options = new OrtSession.SessionOptions(); - options.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT); - options.setIntraOpNumThreads(Math.max(1, Runtime.getRuntime().availableProcessors() / 4)); - options.setInterOpNumThreads(1); - options.setExecutionMode(OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL); - session = environment.createSession(modelPath, options); + session = environment.createSession(modelPath, options.getOptions()); } catch (OrtException e) { throw new RuntimeException("ONNX Runtime exception", e); } |