about | summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author: Bjørn Christian Seime <bjorncs@yahooinc.com> 2023-04-17 13:14:54 +0200
committer: Bjørn Christian Seime <bjorncs@yahooinc.com> 2023-04-17 13:16:07 +0200
commit: f0fb03bfe04c85fb50e21bbfaffb85cc5cd00c9a (patch)
tree: 0eb1ae0e6e297a5e280e41d72847b9373b7c3c14
parent: ee613a99dc15b6acaaf923c60d76fe9428c0aee8 (diff)
Use quarter vcpu by default if execution mode is parallel
-rw-r--r-- model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java | 14
1 files changed, 9 insertions, 5 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
index 1ed219a8560..a980ca984ec 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
@@ -7,6 +7,9 @@ import ai.onnxruntime.OrtSession;
import java.util.Objects;
+import static ai.onnxruntime.OrtSession.SessionOptions.ExecutionMode.PARALLEL;
+import static ai.onnxruntime.OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL;
+
/**
* Session options for ONNX Runtime evaluation
*
@@ -24,9 +27,10 @@ public class OnnxEvaluatorOptions {
public OnnxEvaluatorOptions() {
// Defaults:
optimizationLevel = OrtSession.SessionOptions.OptLevel.ALL_OPT;
- executionMode = OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL;
- interOpThreads = 1;
- intraOpThreads = Math.max(1, (int) Math.ceil(((double) Runtime.getRuntime().availableProcessors()) / 4));
+ executionMode = SEQUENTIAL;
+ int quarterVcpu = Math.max(1, (int) Math.ceil(Runtime.getRuntime().availableProcessors() / 4d));
+ interOpThreads = quarterVcpu;
+ intraOpThreads = quarterVcpu;
gpuDeviceNumber = -1;
gpuDeviceRequired = false;
}
@@ -35,7 +39,7 @@ public class OnnxEvaluatorOptions {
OrtSession.SessionOptions options = new OrtSession.SessionOptions();
options.setOptimizationLevel(optimizationLevel);
options.setExecutionMode(executionMode);
- options.setInterOpNumThreads(interOpThreads);
+ options.setInterOpNumThreads(executionMode == PARALLEL ? interOpThreads : 1);
options.setIntraOpNumThreads(intraOpThreads);
if (loadCuda) {
options.addCUDA(gpuDeviceNumber);
@@ -47,7 +51,7 @@ public class OnnxEvaluatorOptions {
if ("parallel".equalsIgnoreCase(mode)) {
executionMode = OrtSession.SessionOptions.ExecutionMode.PARALLEL;
} else if ("sequential".equalsIgnoreCase(mode)) {
- executionMode = OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL;
+ executionMode = SEQUENTIAL;
}
}