aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src
diff options
context:
space:
mode:
authorMartin Polden <mpolden@mpolden.no>2023-02-07 16:04:57 +0100
committerMartin Polden <mpolden@mpolden.no>2023-02-08 09:52:53 +0100
commitc47f27f1c3362b459e276c59ebcd09ab259b710e (patch)
treef68963b9b91558d9418da0cb88e5a16d67a52034 /model-integration/src
parentd03bf2ed1b239f4998bdfd6580965bcd0a7d62a4 (diff)
Allow fallback to CPU if nodes are provisioned without GPU
Diffstat (limited to 'model-integration/src')
-rw-r--r--model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java12
-rw-r--r--model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java25
2 files changed, 24 insertions, 13 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
index ebed464421b..563ef911f8f 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
@@ -31,7 +31,7 @@ public class OnnxEvaluator {
public OnnxEvaluator(String modelPath, OnnxEvaluatorOptions options) {
environment = OrtEnvironment.getEnvironment();
- session = createSession(modelPath, environment, options);
+ session = createSession(modelPath, environment, options, true);
}
public Tensor evaluate(Map<String, Tensor> inputs, String output) {
@@ -86,18 +86,22 @@ public class OnnxEvaluator {
}
}
- private static OrtSession createSession(String modelPath, OrtEnvironment environment, OnnxEvaluatorOptions options) {
+ private static OrtSession createSession(String modelPath, OrtEnvironment environment, OnnxEvaluatorOptions options, boolean tryCuda) {
if (options == null) {
options = new OnnxEvaluatorOptions();
}
try {
- return environment.createSession(modelPath, options.getOptions());
+ return environment.createSession(modelPath, options.getOptions(tryCuda && options.requestingGpu()));
} catch (OrtException e) {
if (e.getCode() == OrtException.OrtErrorCode.ORT_NO_SUCHFILE) {
throw new IllegalArgumentException("No such file: " + modelPath);
}
+ if (tryCuda && isCudaError(e) && !options.gpuDeviceRequired()) {
+ // Failed in CUDA native code, but GPU device is optional, so we can proceed without it
+ return createSession(modelPath, environment, options, false);
+ }
if (isCudaError(e)) {
- throw new IllegalArgumentException("GPU device " + options.gpuDevice() + " requested, but CUDA initialization failed", e);
+ throw new IllegalArgumentException("GPU device is requested, but CUDA initialization failed", e);
}
throw new RuntimeException("ONNX Runtime exception", e);
}
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
index f838a3b3f7f..b6de9698f1a 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
@@ -16,7 +16,8 @@ public class OnnxEvaluatorOptions {
private OrtSession.SessionOptions.ExecutionMode executionMode;
private int interOpThreads;
private int intraOpThreads;
- private int gpuDevice;
+ private int gpuDeviceNumber;
+ private boolean gpuDeviceRequired;
public OnnxEvaluatorOptions() {
// Defaults:
@@ -24,17 +25,18 @@ public class OnnxEvaluatorOptions {
executionMode = OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL;
interOpThreads = 1;
intraOpThreads = Math.max(1, (int) Math.ceil(((double) Runtime.getRuntime().availableProcessors()) / 4));
- gpuDevice = -1;
+ gpuDeviceNumber = -1;
+ gpuDeviceRequired = false;
}
- public OrtSession.SessionOptions getOptions() throws OrtException {
+ public OrtSession.SessionOptions getOptions(boolean loadCuda) throws OrtException {
OrtSession.SessionOptions options = new OrtSession.SessionOptions();
options.setOptimizationLevel(optimizationLevel);
options.setExecutionMode(executionMode);
options.setInterOpNumThreads(interOpThreads);
options.setIntraOpNumThreads(intraOpThreads);
- if (gpuDevice > -1) {
- options.addCUDA(gpuDevice);
+ if (loadCuda) {
+ options.addCUDA(gpuDeviceNumber);
}
return options;
}
@@ -59,12 +61,17 @@ public class OnnxEvaluatorOptions {
}
}
- public void setGpuDevice(int deviceNumber) {
- this.gpuDevice = deviceNumber;
+ public void setGpuDevice(int deviceNumber, boolean required) {
+ this.gpuDeviceNumber = deviceNumber;
+ this.gpuDeviceRequired = required;
}
- public int gpuDevice() {
- return gpuDevice;
+ public boolean requestingGpu() {
+ return gpuDeviceNumber > -1;
+ }
+
+ public boolean gpuDeviceRequired() {
+ return gpuDeviceRequired;
}
}