summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorMartin Polden <mpolden@mpolden.no>2023-01-23 14:34:35 +0100
committerMartin Polden <mpolden@mpolden.no>2023-01-23 14:34:35 +0100
commitc9ba67dd5b8402125ea84a5c5fd12562ca7ebd15 (patch)
tree415a4d8ba38f8a9b52cd158f1540d19cf337bccf /model-integration
parent4ffc7ef27da2a8d824c9f04a80f89e569e36b322 (diff)
Support configuration of GPU device to use in ONNX model
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java25
1 files changed, 24 insertions, 1 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
index 8467040e5c0..fceb63e6ae6 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
@@ -2,7 +2,6 @@
package ai.vespa.modelintegration.evaluator;
-import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
@@ -17,6 +16,8 @@ public class OnnxEvaluatorOptions {
private OrtSession.SessionOptions.ExecutionMode executionMode;
private int interOpThreads;
private int intraOpThreads;
+ private int gpuDeviceNumber;
+ private boolean gpuDeviceRequired;
public OnnxEvaluatorOptions() {
// Defaults:
@@ -24,6 +25,8 @@ public class OnnxEvaluatorOptions {
executionMode = OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL;
interOpThreads = 1;
intraOpThreads = Math.max(1, (int) Math.ceil(((double) Runtime.getRuntime().availableProcessors()) / 4));
+ gpuDeviceNumber = -1;
+ gpuDeviceRequired = false;
}
public OrtSession.SessionOptions getOptions() throws OrtException {
@@ -32,9 +35,24 @@ public class OnnxEvaluatorOptions {
options.setExecutionMode(executionMode);
options.setInterOpNumThreads(interOpThreads);
options.setIntraOpNumThreads(intraOpThreads);
+ addCuda(options);
return options;
}
+ private void addCuda(OrtSession.SessionOptions options) throws OrtException {
+ if (gpuDeviceNumber < 0) return;
+ try {
+ options.addCUDA(gpuDeviceNumber);
+ } catch (OrtException e) {
+ if (e.getCode() != OrtException.OrtErrorCode.ORT_EP_FAIL) {
+ throw e;
+ }
+ if (gpuDeviceRequired) {
+ throw new IllegalArgumentException("GPU device " + gpuDeviceNumber + " is required, but CUDA backend could not be initialized", e);
+ }
+ }
+ }
+
public void setExecutionMode(String mode) {
if ("parallel".equalsIgnoreCase(mode)) {
executionMode = OrtSession.SessionOptions.ExecutionMode.PARALLEL;
@@ -55,4 +73,9 @@ public class OnnxEvaluatorOptions {
}
}
+ public void setGpuDevice(int deviceNumber, boolean required) {
+ this.gpuDeviceNumber = deviceNumber;
+ this.gpuDeviceRequired = required;
+ }
+
}