diff options
author | Martin Polden <mpolden@mpolden.no> | 2023-01-23 14:34:35 +0100 |
---|---|---|
committer | Martin Polden <mpolden@mpolden.no> | 2023-01-23 14:34:35 +0100 |
commit | c9ba67dd5b8402125ea84a5c5fd12562ca7ebd15 (patch) | |
tree | 415a4d8ba38f8a9b52cd158f1540d19cf337bccf /model-integration/src/main/java/ai | |
parent | 4ffc7ef27da2a8d824c9f04a80f89e569e36b322 (diff) |
Support configuration of GPU device to use in ONNX model
Diffstat (limited to 'model-integration/src/main/java/ai')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java | 25 |
1 files changed, 24 insertions, 1 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
index 8467040e5c0..fceb63e6ae6 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
@@ -2,7 +2,6 @@
 
 package ai.vespa.modelintegration.evaluator;
 
-import ai.onnxruntime.OrtEnvironment;
 import ai.onnxruntime.OrtException;
 import ai.onnxruntime.OrtSession;
 
@@ -17,6 +16,8 @@ public class OnnxEvaluatorOptions {
     private OrtSession.SessionOptions.ExecutionMode executionMode;
     private int interOpThreads;
     private int intraOpThreads;
+    private int gpuDeviceNumber;
+    private boolean gpuDeviceRequired;
 
     public OnnxEvaluatorOptions() {
         // Defaults:
@@ -24,6 +25,8 @@ public class OnnxEvaluatorOptions {
         executionMode = OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL;
         interOpThreads = 1;
         intraOpThreads = Math.max(1, (int) Math.ceil(((double) Runtime.getRuntime().availableProcessors()) / 4));
+        gpuDeviceNumber = -1;
+        gpuDeviceRequired = false;
     }
 
     public OrtSession.SessionOptions getOptions() throws OrtException {
@@ -32,9 +35,24 @@ public class OnnxEvaluatorOptions {
         options.setExecutionMode(executionMode);
         options.setInterOpNumThreads(interOpThreads);
         options.setIntraOpNumThreads(intraOpThreads);
+        addCuda(options);
         return options;
     }
 
+    private void addCuda(OrtSession.SessionOptions options) throws OrtException {
+        if (gpuDeviceNumber < 0) return;
+        try {
+            options.addCUDA(gpuDeviceNumber);
+        } catch (OrtException e) {
+            if (e.getCode() != OrtException.OrtErrorCode.ORT_EP_FAIL) {
+                throw e;
+            }
+            if (gpuDeviceRequired) {
+                throw new IllegalArgumentException("GPU device " + gpuDeviceNumber + " is required, but CUDA backend could not be initialized", e);
+            }
+        }
+    }
+
     public void setExecutionMode(String mode) {
         if ("parallel".equalsIgnoreCase(mode)) {
             executionMode = OrtSession.SessionOptions.ExecutionMode.PARALLEL;
@@ -55,4 +73,9 @@ public class OnnxEvaluatorOptions {
         }
     }
 
+    public void setGpuDevice(int deviceNumber, boolean required) {
+        this.gpuDeviceNumber = deviceNumber;
+        this.gpuDeviceRequired = required;
+    }
+
 }