aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src
diff options
context:
space:
mode:
author Martin Polden <mpolden@mpolden.no> 2023-01-26 09:35:59 +0100
committer Martin Polden <mpolden@mpolden.no> 2023-01-26 10:12:03 +0100
commit 87f21c2a817e43e2709e2c05f2453f8bd75b737c (patch)
tree d0c2568fecc5e73a8647b6bede9449674cde293b /model-integration/src
parent db1bcbfe4768e787d8794c921caea47c3e7cf58f (diff)
Skip CUDA entirely if GPU device is optional
CUDA may fail after its library is loaded, e.g. when the session is created.
Diffstat (limited to 'model-integration/src')
-rw-r--r-- model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java 40
-rw-r--r-- model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java 23
2 files changed, 39 insertions, 24 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
index 125707c9aaa..bb40333f9b3 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
@@ -30,18 +30,8 @@ public class OnnxEvaluator {
}
public OnnxEvaluator(String modelPath, OnnxEvaluatorOptions options) {
- try {
- if (options == null) {
- options = new OnnxEvaluatorOptions();
- }
- environment = OrtEnvironment.getEnvironment();
- session = environment.createSession(modelPath, options.getOptions());
- } catch (OrtException e) {
- if (e.getCode() == OrtException.OrtErrorCode.ORT_NO_SUCHFILE) {
- throw new IllegalArgumentException("No such file: "+modelPath);
- }
- throw new RuntimeException("ONNX Runtime exception", e);
- }
+ environment = OrtEnvironment.getEnvironment();
+ session = createSession(modelPath, environment, options, true);
}
public Tensor evaluate(Map<String, Tensor> inputs, String output) {
@@ -96,6 +86,32 @@ public class OnnxEvaluator {
}
}
+ private static OrtSession createSession(String modelPath, OrtEnvironment environment, OnnxEvaluatorOptions options, boolean tryCuda) {
+ if (options == null) {
+ options = new OnnxEvaluatorOptions();
+ }
+ try {
+ return environment.createSession(modelPath, options.getOptions(tryCuda && options.hasGpuDevice()));
+ } catch (OrtException e) {
+ if (e.getCode() == OrtException.OrtErrorCode.ORT_NO_SUCHFILE) {
+ throw new IllegalArgumentException("No such file: " + modelPath);
+ }
+ if (tryCuda && isCudaError(e) && !options.gpuDeviceRequired()) {
+ // Failed in CUDA native code, but GPU device is optional, so we can proceed without it
+ return createSession(modelPath, environment, options, false);
+ }
+ throw new RuntimeException("ONNX Runtime exception", e);
+ }
+ }
+
+ private static boolean isCudaError(OrtException e) {
+ return switch (e.getCode()) {
+ case ORT_FAIL -> e.getMessage().contains("cudaError");
+ case ORT_EP_FAIL -> e.getMessage().contains("Failed to find CUDA");
+ default -> false;
+ };
+ }
+
public static boolean isRuntimeAvailable() {
return isRuntimeAvailable("");
}
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
index cdbce760d92..9d82531df02 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
@@ -29,25 +29,16 @@ public class OnnxEvaluatorOptions {
gpuDeviceRequired = false;
}
- public OrtSession.SessionOptions getOptions() throws OrtException {
+ public OrtSession.SessionOptions getOptions(boolean loadCuda) throws OrtException {
OrtSession.SessionOptions options = new OrtSession.SessionOptions();
options.setOptimizationLevel(optimizationLevel);
options.setExecutionMode(executionMode);
options.setInterOpNumThreads(interOpThreads);
options.setIntraOpNumThreads(intraOpThreads);
- addCuda(options);
- return options;
- }
-
- private void addCuda(OrtSession.SessionOptions options) {
- if (gpuDeviceNumber < 0) return;
- try {
+ if (loadCuda) {
options.addCUDA(gpuDeviceNumber);
- } catch (OrtException e) {
- if (gpuDeviceRequired) {
- throw new IllegalArgumentException("GPU device " + gpuDeviceNumber + " is required, but CUDA backend could not be initialized", e);
- }
}
+ return options;
}
public void setExecutionMode(String mode) {
@@ -75,4 +66,12 @@ public class OnnxEvaluatorOptions {
this.gpuDeviceRequired = required;
}
+ public boolean hasGpuDevice() {
+ return gpuDeviceNumber > -1;
+ }
+
+ public boolean gpuDeviceRequired() {
+ return gpuDeviceRequired;
+ }
+
}