diff options
author | Martin Polden <mpolden@mpolden.no> | 2023-01-26 09:35:59 +0100 |
---|---|---|
committer | Martin Polden <mpolden@mpolden.no> | 2023-01-26 10:12:03 +0100 |
commit | 87f21c2a817e43e2709e2c05f2453f8bd75b737c (patch) | |
tree | d0c2568fecc5e73a8647b6bede9449674cde293b /model-integration | |
parent | db1bcbfe4768e787d8794c921caea47c3e7cf58f (diff) |
Skip CUDA entirely if GPU device is optional
CUDA may fail after its library is loaded, e.g. when the session is created.
Diffstat (limited to 'model-integration')
2 files changed, 39 insertions, 24 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java index 125707c9aaa..bb40333f9b3 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java @@ -30,18 +30,8 @@ public class OnnxEvaluator { } public OnnxEvaluator(String modelPath, OnnxEvaluatorOptions options) { - try { - if (options == null) { - options = new OnnxEvaluatorOptions(); - } - environment = OrtEnvironment.getEnvironment(); - session = environment.createSession(modelPath, options.getOptions()); - } catch (OrtException e) { - if (e.getCode() == OrtException.OrtErrorCode.ORT_NO_SUCHFILE) { - throw new IllegalArgumentException("No such file: "+modelPath); - } - throw new RuntimeException("ONNX Runtime exception", e); - } + environment = OrtEnvironment.getEnvironment(); + session = createSession(modelPath, environment, options, true); } public Tensor evaluate(Map<String, Tensor> inputs, String output) { @@ -96,6 +86,32 @@ public class OnnxEvaluator { } } + private static OrtSession createSession(String modelPath, OrtEnvironment environment, OnnxEvaluatorOptions options, boolean tryCuda) { + if (options == null) { + options = new OnnxEvaluatorOptions(); + } + try { + return environment.createSession(modelPath, options.getOptions(tryCuda && options.hasGpuDevice())); + } catch (OrtException e) { + if (e.getCode() == OrtException.OrtErrorCode.ORT_NO_SUCHFILE) { + throw new IllegalArgumentException("No such file: " + modelPath); + } + if (tryCuda && isCudaError(e) && !options.gpuDeviceRequired()) { + // Failed in CUDA native code, but GPU device is optional, so we can proceed without it + return createSession(modelPath, environment, options, false); + } + throw new RuntimeException("ONNX Runtime exception", e); + } + } + + private static boolean isCudaError(OrtException e) { + return switch (e.getCode()) { + case ORT_FAIL -> e.getMessage().contains("cudaError"); + case ORT_EP_FAIL -> e.getMessage().contains("Failed to find CUDA"); + default -> false; + }; + } + public static boolean isRuntimeAvailable() { return isRuntimeAvailable(""); } diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java index cdbce760d92..9d82531df02 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java @@ -29,25 +29,16 @@ public class OnnxEvaluatorOptions { gpuDeviceRequired = false; } - public OrtSession.SessionOptions getOptions() throws OrtException { + public OrtSession.SessionOptions getOptions(boolean loadCuda) throws OrtException { OrtSession.SessionOptions options = new OrtSession.SessionOptions(); options.setOptimizationLevel(optimizationLevel); options.setExecutionMode(executionMode); options.setInterOpNumThreads(interOpThreads); options.setIntraOpNumThreads(intraOpThreads); - addCuda(options); - return options; - } - - private void addCuda(OrtSession.SessionOptions options) { - if (gpuDeviceNumber < 0) return; - try { + if (loadCuda) { options.addCUDA(gpuDeviceNumber); - } catch (OrtException e) { - if (gpuDeviceRequired) { - throw new IllegalArgumentException("GPU device " + gpuDeviceNumber + " is required, but CUDA backend could not be initialized", e); - } } + return options; } public void setExecutionMode(String mode) { @@ -75,4 +66,12 @@ public class OnnxEvaluatorOptions { this.gpuDeviceRequired = required; } + public boolean hasGpuDevice() { + return gpuDeviceNumber > -1; + } + + public boolean gpuDeviceRequired() { + return gpuDeviceRequired; + } + } |