diff options
author | Arne H Juul <arnej27959@users.noreply.github.com> | 2023-04-19 14:34:51 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-04-19 14:34:51 +0200 |
commit | dd2f89843c95423c243ee858ce07c8f9554bab54 (patch) | |
tree | b90515ff9137a2204f1559e4063fa9d7abd138b6 | |
parent | 1b0363508a5a14e8d9d39bb7421ed1040b1a2a6a (diff) | |
parent | f0fb03bfe04c85fb50e21bbfaffb85cc5cd00c9a (diff) |
Merge pull request #26753 from vespa-engine/bjorncs/global-phase
Use quarter vcpu by default if execution mode is parallel
-rw-r--r-- | model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java | 14 |
1 file changed, 9 insertions, 5 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
index 1ed219a8560..a980ca984ec 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
@@ -7,6 +7,9 @@ import ai.onnxruntime.OrtSession;
 
 import java.util.Objects;
 
+import static ai.onnxruntime.OrtSession.SessionOptions.ExecutionMode.PARALLEL;
+import static ai.onnxruntime.OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL;
+
 /**
  * Session options for ONNX Runtime evaluation
  *
@@ -24,9 +27,10 @@ public class OnnxEvaluatorOptions {
     public OnnxEvaluatorOptions() {
         // Defaults:
         optimizationLevel = OrtSession.SessionOptions.OptLevel.ALL_OPT;
-        executionMode = OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL;
-        interOpThreads = 1;
-        intraOpThreads = Math.max(1, (int) Math.ceil(((double) Runtime.getRuntime().availableProcessors()) / 4));
+        executionMode = SEQUENTIAL;
+        int quarterVcpu = Math.max(1, (int) Math.ceil(Runtime.getRuntime().availableProcessors() / 4d));
+        interOpThreads = quarterVcpu;
+        intraOpThreads = quarterVcpu;
         gpuDeviceNumber = -1;
         gpuDeviceRequired = false;
     }
@@ -35,7 +39,7 @@ public class OnnxEvaluatorOptions {
         OrtSession.SessionOptions options = new OrtSession.SessionOptions();
         options.setOptimizationLevel(optimizationLevel);
         options.setExecutionMode(executionMode);
-        options.setInterOpNumThreads(interOpThreads);
+        options.setInterOpNumThreads(executionMode == PARALLEL ? interOpThreads : 1);
         options.setIntraOpNumThreads(intraOpThreads);
         if (loadCuda) {
             options.addCUDA(gpuDeviceNumber);
@@ -47,7 +51,7 @@ public class OnnxEvaluatorOptions {
         if ("parallel".equalsIgnoreCase(mode)) {
             executionMode = OrtSession.SessionOptions.ExecutionMode.PARALLEL;
         } else if ("sequential".equalsIgnoreCase(mode)) {
-            executionMode = OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL;
+            executionMode = SEQUENTIAL;
         }
     }
 