aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorArne H Juul <arnej27959@users.noreply.github.com>2023-04-19 14:34:51 +0200
committerGitHub <noreply@github.com>2023-04-19 14:34:51 +0200
commitdd2f89843c95423c243ee858ce07c8f9554bab54 (patch)
treeb90515ff9137a2204f1559e4063fa9d7abd138b6
parent1b0363508a5a14e8d9d39bb7421ed1040b1a2a6a (diff)
parentf0fb03bfe04c85fb50e21bbfaffb85cc5cd00c9a (diff)
Merge pull request #26753 from vespa-engine/bjorncs/global-phase
Use quarter vcpu by default if execution mode is parallel
-rw-r--r--model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java14
1 files changed, 9 insertions, 5 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
index 1ed219a8560..a980ca984ec 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
@@ -7,6 +7,9 @@ import ai.onnxruntime.OrtSession;
import java.util.Objects;
+import static ai.onnxruntime.OrtSession.SessionOptions.ExecutionMode.PARALLEL;
+import static ai.onnxruntime.OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL;
+
/**
* Session options for ONNX Runtime evaluation
*
@@ -24,9 +27,10 @@ public class OnnxEvaluatorOptions {
public OnnxEvaluatorOptions() {
// Defaults:
optimizationLevel = OrtSession.SessionOptions.OptLevel.ALL_OPT;
- executionMode = OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL;
- interOpThreads = 1;
- intraOpThreads = Math.max(1, (int) Math.ceil(((double) Runtime.getRuntime().availableProcessors()) / 4));
+ executionMode = SEQUENTIAL;
+ int quarterVcpu = Math.max(1, (int) Math.ceil(Runtime.getRuntime().availableProcessors() / 4d));
+ interOpThreads = quarterVcpu;
+ intraOpThreads = quarterVcpu;
gpuDeviceNumber = -1;
gpuDeviceRequired = false;
}
@@ -35,7 +39,7 @@ public class OnnxEvaluatorOptions {
OrtSession.SessionOptions options = new OrtSession.SessionOptions();
options.setOptimizationLevel(optimizationLevel);
options.setExecutionMode(executionMode);
- options.setInterOpNumThreads(interOpThreads);
+ options.setInterOpNumThreads(executionMode == PARALLEL ? interOpThreads : 1);
options.setIntraOpNumThreads(intraOpThreads);
if (loadCuda) {
options.addCUDA(gpuDeviceNumber);
@@ -47,7 +51,7 @@ public class OnnxEvaluatorOptions {
if ("parallel".equalsIgnoreCase(mode)) {
executionMode = OrtSession.SessionOptions.ExecutionMode.PARALLEL;
} else if ("sequential".equalsIgnoreCase(mode)) {
- executionMode = OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL;
+ executionMode = SEQUENTIAL;
}
}